diff --git a/.github/workflows/e2e-java-tracer.yaml b/.github/workflows/e2e-java-tracer.yaml index 7e92e9eee..6ed17ce90 100644 --- a/.github/workflows/e2e-java-tracer.yaml +++ b/.github/workflows/e2e-java-tracer.yaml @@ -3,17 +3,9 @@ name: E2E - Java Tracer on: pull_request: paths: - - 'codeflash/languages/java/**' - - 'codeflash/languages/base.py' - - 'codeflash/languages/registry.py' - - 'codeflash/tracer.py' - - 'codeflash/benchmarking/function_ranker.py' - - 'codeflash/discovery/functions_to_optimize.py' - - 'codeflash/optimization/**' - - 'codeflash/verification/**' + - 'codeflash/**' - 'codeflash-java-runtime/**' - - 'tests/test_languages/fixtures/java_tracer_e2e/**' - - 'tests/scripts/end_to_end_test_java_tracer.py' + - 'tests/**' - '.github/workflows/e2e-java-tracer.yaml' workflow_dispatch: diff --git a/code_to_optimize/java-gradle/codeflash.toml b/code_to_optimize/java-gradle/codeflash.toml deleted file mode 100644 index bf6e45279..000000000 --- a/code_to_optimize/java-gradle/codeflash.toml +++ /dev/null @@ -1,4 +0,0 @@ -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -formatter-cmds = [] diff --git a/code_to_optimize/java/codeflash.toml b/code_to_optimize/java/codeflash.toml deleted file mode 100644 index 4016df28a..000000000 --- a/code_to_optimize/java/codeflash.toml +++ /dev/null @@ -1,6 +0,0 @@ -# Codeflash configuration for Java project - -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -formatter-cmds = [] diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java b/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java index f4b9ec453..3a73038c1 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java @@ -12,20 +12,179 @@ public class ReplayHelper { - private final Connection db; + private final Connection traceDb; + + // Codeflash instrumentation state — read from environment variables once + private final String mode; // "behavior", "performance", or null + private final int loopIndex; + private final String testIteration; + private final String outputFile; // SQLite path for behavior capture + private final int innerIterations; // for performance looping + + // Behavior mode: lazily opened SQLite connection for writing results + private Connection behaviorDb; + private boolean behaviorDbInitialized; public ReplayHelper(String traceDbPath) { try { - this.db = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath); + this.traceDb = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath); } catch (SQLException e) { throw new RuntimeException("Failed to open trace database: " + traceDbPath, e); } + + // Read codeflash instrumentation env vars (set by the test runner) + this.mode = System.getenv("CODEFLASH_MODE"); + this.loopIndex = parseIntEnv("CODEFLASH_LOOP_INDEX", 1); + this.testIteration = getEnvOrDefault("CODEFLASH_TEST_ITERATION", "0"); + this.outputFile = System.getenv("CODEFLASH_OUTPUT_FILE"); + this.innerIterations = parseIntEnv("CODEFLASH_INNER_ITERATIONS", 10); } public void replay(String className, String methodName, String descriptor, int invocationIndex) throws Exception { - // Query the function_calls table for this method at the given index + // Deserialize args and resolve method (done once, outside timing) + Object[] allArgs = loadArgs(className, methodName, descriptor, invocationIndex); + Class targetClass = Class.forName(className); + + Type[] paramTypes = Type.getArgumentTypes(descriptor); + Class[] paramClasses = new Class[paramTypes.length]; + for (int i = 0; i < paramTypes.length; i++) { + paramClasses[i] = typeToClass(paramTypes[i]); + } + + Method method = targetClass.getDeclaredMethod(methodName, paramClasses); + method.setAccessible(true); + boolean isStatic = Modifier.isStatic(method.getModifiers()); + + Object instance = null; + if (!isStatic) { + try { + java.lang.reflect.Constructor ctor = targetClass.getDeclaredConstructor(); + ctor.setAccessible(true); + instance = ctor.newInstance(); + } catch (NoSuchMethodException e) { + instance = new org.objenesis.ObjenesisStd().newInstance(targetClass); + } + } + + // Get the calling test method name from the stack trace + String testMethodName = getCallingTestMethodName(); + // Module name = the test class that called us + String testClassName = getCallingTestClassName(); + + if ("behavior".equals(mode)) { + replayBehavior(method, instance, allArgs, className, methodName, testClassName, testMethodName); + } else if ("performance".equals(mode)) { + replayPerformance(method, instance, allArgs, className, methodName, testClassName, testMethodName); + } else { + // No codeflash mode — just invoke (trace-only or manual testing) + method.invoke(instance, allArgs); + } + } + + private void replayBehavior(Method method, Object instance, Object[] args, + String className, String methodName, + String testClassName, String testMethodName) throws Exception { + String invId = testIteration + "_" + testMethodName; + + // Print start marker (same format as behavior instrumentation) + System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopIndex + ":" + invId + "######$!"); + + long startNs = System.nanoTime(); + Object result; + try { + result = method.invoke(instance, args); + } catch (java.lang.reflect.InvocationTargetException e) { + throw (Exception) e.getCause(); + } + long durationNs = System.nanoTime() - startNs; + + // Print end marker + System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopIndex + ":" + invId + ":" + durationNs + "######!"); + + // Write return value to SQLite for correctness comparison + if (outputFile != null && !outputFile.isEmpty()) { + writeBehaviorResult(testClassName, testMethodName, methodName, invId, durationNs, result); + } + } + + private void replayPerformance(Method method, Object instance, Object[] args, + String className, String methodName, + String testClassName, String testMethodName) throws Exception { + // Performance mode: run inner loop for JIT warmup, print timing for each iteration + int maxInner = innerIterations; + for (int inner = 0; inner < maxInner; inner++) { + int loopId = (loopIndex - 1) * maxInner + inner; + String invId = testMethodName; + + // Print start marker + System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopId + ":" + invId + "######$!"); + + long startNs = System.nanoTime(); + try { + method.invoke(instance, args); + } catch (java.lang.reflect.InvocationTargetException e) { + // Swallow — performance mode doesn't check correctness + } + long durationNs = System.nanoTime() - startNs; + + // Print end marker + System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopId + ":" + invId + ":" + durationNs + "######!"); + } + } + + private void writeBehaviorResult(String testClassName, String testMethodName, + String functionName, String invId, + long durationNs, Object result) { + try { + ensureBehaviorDb(); + String sql = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement ps = behaviorDb.prepareStatement(sql)) { + ps.setString(1, testClassName); // test_module_path + ps.setString(2, testClassName); // test_class_name + ps.setString(3, testMethodName); // test_function_name + ps.setString(4, functionName); // function_getting_tested + ps.setInt(5, loopIndex); // loop_index + ps.setString(6, invId); // iteration_id + ps.setLong(7, durationNs); // runtime + ps.setBytes(8, serializeResult(result)); // return_value + ps.setString(9, "function_call"); // verification_type + ps.executeUpdate(); + } + } catch (Exception e) { + System.err.println("ReplayHelper: SQLite behavior write error: " + e.getMessage()); + } + } + + private void ensureBehaviorDb() throws SQLException { + if (behaviorDbInitialized) return; + behaviorDbInitialized = true; + behaviorDb = DriverManager.getConnection("jdbc:sqlite:" + outputFile); + try (java.sql.Statement stmt = behaviorDb.createStatement()) { + stmt.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + } + } + + private byte[] serializeResult(Object result) { + if (result == null) return null; + try { + return Serializer.serialize(result); + } catch (Exception e) { + // Fall back to String.valueOf if Kryo fails + return String.valueOf(result).getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + } + + private Object[] loadArgs(String className, String methodName, String descriptor, int invocationIndex) + throws SQLException { byte[] argsBlob; - try (PreparedStatement stmt = db.prepareStatement( + try (PreparedStatement stmt = traceDb.prepareStatement( "SELECT args FROM function_calls " + "WHERE classname = ? AND function = ? AND descriptor = ? " + "ORDER BY time_ns LIMIT 1 OFFSET ?")) { @@ -43,46 +202,35 @@ public void replay(String className, String methodName, String descriptor, int i } } - // Deserialize args Object deserialized = Serializer.deserialize(argsBlob); if (!(deserialized instanceof Object[])) { throw new RuntimeException("Deserialized args is not Object[], got: " + (deserialized == null ? "null" : deserialized.getClass().getName())); } - Object[] allArgs = (Object[]) deserialized; - - // Load the target class - Class targetClass = Class.forName(className); + return (Object[]) deserialized; + } - // Parse descriptor to find parameter types - Type[] paramTypes = Type.getArgumentTypes(descriptor); - Class[] paramClasses = new Class[paramTypes.length]; - for (int i = 0; i < paramTypes.length; i++) { - paramClasses[i] = typeToClass(paramTypes[i]); + private static String getCallingTestMethodName() { + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + // Walk up: [0]=getStackTrace, [1]=this method, [2]=replay(), [3]=calling test method + for (int i = 3; i < stack.length; i++) { + String method = stack[i].getMethodName(); + if (method.startsWith("replay_")) { + return method; + } } + return stack.length > 3 ? stack[3].getMethodName() : "unknown"; + } - // Find the method - Method method = targetClass.getDeclaredMethod(methodName, paramClasses); - method.setAccessible(true); - - boolean isStatic = Modifier.isStatic(method.getModifiers()); - - if (isStatic) { - method.invoke(null, allArgs); - } else { - // Args contain only explicit parameters (no 'this'). - // Create a default instance via no-arg constructor or Kryo. - Object instance; - try { - java.lang.reflect.Constructor ctor = targetClass.getDeclaredConstructor(); - ctor.setAccessible(true); - instance = ctor.newInstance(); - } catch (NoSuchMethodException e) { - // Fall back to Objenesis instantiation (no constructor needed) - instance = new org.objenesis.ObjenesisStd().newInstance(targetClass); + private static String getCallingTestClassName() { + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + for (int i = 3; i < stack.length; i++) { + String cls = stack[i].getClassName(); + if (cls.contains("ReplayTest") || cls.contains("replay")) { + return cls; } - method.invoke(instance, allArgs); } + return stack.length > 3 ? stack[3].getClassName() : "unknown"; } private static Class typeToClass(Type type) throws ClassNotFoundException { @@ -106,11 +254,23 @@ private static Class typeToClass(Type type) throws ClassNotFoundException { } } + private static int parseIntEnv(String name, int defaultValue) { + String val = System.getenv(name); + if (val == null || val.isEmpty()) return defaultValue; + try { return Integer.parseInt(val); } catch (NumberFormatException e) { return defaultValue; } + } + + private static String getEnvOrDefault(String name, String defaultValue) { + String val = System.getenv(name); + return (val != null && !val.isEmpty()) ? val : defaultValue; + } + public void close() { - try { - if (db != null) db.close(); - } catch (SQLException e) { - System.err.println("Error closing ReplayHelper: " + e.getMessage()); + try { if (traceDb != null) traceDb.close(); } catch (SQLException e) { + System.err.println("Error closing ReplayHelper trace db: " + e.getMessage()); + } + try { if (behaviorDb != null) behaviorDb.close(); } catch (SQLException e) { + System.err.println("Error closing ReplayHelper behavior db: " + e.getMessage()); } } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java index 974c767a9..75c61de3a 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java @@ -22,11 +22,6 @@ public byte[] transform(ClassLoader loader, String className, return null; } - // Skip instrumentation if we're inside a recording call (e.g., during Kryo serialization) - if (TraceRecorder.isRecording()) { - return null; - } - // Skip internal JDK, framework, and synthetic classes if (className.startsWith("java/") || className.startsWith("javax/") diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index d76e60a11..c611f5cd9 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -185,11 +185,17 @@ def process_pyproject_config(args: Namespace) -> Namespace: args.ignore_paths = normalize_ignore_paths(args.ignore_paths, base_path=args.module_root) # If module-root is "." then all imports are relatives to it. # in this case, the ".." becomes outside project scope, causing issues with un-importable paths - args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path) + args.project_root = project_root_from_module_root(Path(args.module_root), pyproject_file_path) args.tests_root = Path(args.tests_root).resolve() if args.benchmarks_root: args.benchmarks_root = Path(args.benchmarks_root).resolve() args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path) + + if is_java_project and pyproject_file_path.is_dir(): + # For Java projects, pyproject_file_path IS the project root directory (not a file). + # Override project_root which may have resolved to a sub-module. + args.project_root = pyproject_file_path.resolve() + args.test_project_root = pyproject_file_path.resolve() if is_LSP_enabled(): args.all = None return args @@ -208,8 +214,6 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) return current.resolve() if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists(): return current.resolve() - if (current / "codeflash.toml").exists(): - return current.resolve() current = current.parent return module_root.parent.resolve() @@ -370,7 +374,7 @@ def _build_parser() -> ArgumentParser: subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension") subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow") - trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.") + trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.", add_help=False) trace_optimize.add_argument( "--max-function-count", diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index ef21ce051..1d0f13df5 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -12,8 +12,29 @@ ALL_CONFIG_FILES: dict[Path, dict[str, Path]] = {} +def _try_parse_java_build_config() -> tuple[dict[str, Any], Path] | None: + """Detect Java project from build files and parse config from pom.xml/gradle.properties. + + Returns (config_dict, project_root) if a Java project is found, None otherwise. + """ + dir_path = Path.cwd() + while dir_path != dir_path.parent: + if ( + (dir_path / "pom.xml").exists() + or (dir_path / "build.gradle").exists() + or (dir_path / "build.gradle.kts").exists() + ): + from codeflash.languages.java.build_tools import parse_java_project_config + + config = parse_java_project_config(dir_path) + if config is not None: + return config, dir_path + dir_path = dir_path.parent + return None + + def find_pyproject_toml(config_file: Path | None = None) -> Path: - # Find the pyproject.toml or codeflash.toml file on the root of the project + # Find the pyproject.toml file on the root of the project if config_file is not None: config_file = Path(config_file) @@ -29,21 +50,13 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: # see if it was encountered before in search if cur_path in PYPROJECT_TOML_CACHE: return PYPROJECT_TOML_CACHE[cur_path] - # map current path to closest file - check both pyproject.toml and codeflash.toml while dir_path != dir_path.parent: - # First check pyproject.toml (Python projects) config_file = dir_path / "pyproject.toml" if config_file.exists(): PYPROJECT_TOML_CACHE[cur_path] = config_file return config_file - # Then check codeflash.toml (Java/other projects) - config_file = dir_path / "codeflash.toml" - if config_file.exists(): - PYPROJECT_TOML_CACHE[cur_path] = config_file - return config_file - # Search in parent directories dir_path = dir_path.parent - msg = f"Could not find pyproject.toml or codeflash.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `codeflash init`, or pass the path to the config file with the --config-file argument." + msg = f"Could not find pyproject.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `codeflash init`, or pass the path to the config file with the --config-file argument." raise ValueError(msg) from None @@ -90,33 +103,29 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]: return list(list_of_conftest_files) -# TODO for claude: There should be different functions to parse it per language, which should be chosen during runtime def parse_config_file( config_file_path: Path | None = None, override_formatter_check: bool = False ) -> tuple[dict[str, Any], Path]: + # Java projects: read config from pom.xml/gradle.properties (no standalone config file needed) + if config_file_path is None: + java_config = _try_parse_java_build_config() + if java_config is not None: + config, project_root = java_config + return config, project_root + package_json_path = find_package_json(config_file_path) pyproject_toml_path = find_closest_config_file("pyproject.toml") if config_file_path is None else None - codeflash_toml_path = find_closest_config_file("codeflash.toml") if config_file_path is None else None - - # Pick the closest toml config (pyproject.toml or codeflash.toml). - # Java projects use codeflash.toml; Python projects use pyproject.toml. - closest_toml_path = None - if pyproject_toml_path and codeflash_toml_path: - closest_toml_path = max(pyproject_toml_path, codeflash_toml_path, key=lambda p: len(p.parent.parts)) - else: - closest_toml_path = pyproject_toml_path or codeflash_toml_path # When both config files exist, prefer the one closer to CWD. # This prevents a parent-directory package.json (e.g., monorepo root) - # from overriding a closer pyproject.toml or codeflash.toml. + # from overriding a closer pyproject.toml. use_package_json = False if package_json_path: - if closest_toml_path is None: + if pyproject_toml_path is None: use_package_json = True else: - # Compare depth: more path parts = closer to CWD = more specific package_json_depth = len(package_json_path.parent.parts) - toml_depth = len(closest_toml_path.parent.parts) + toml_depth = len(pyproject_toml_path.parent.parts) use_package_json = package_json_depth >= toml_depth if use_package_json: @@ -160,7 +169,7 @@ def parse_config_file( if config == {} and lsp_mode: return {}, config_file_path - # Preserve language field if present (important for Java/JS projects using codeflash.toml) + # Preserve language field if present (important for JS/TS projects) # default values: path_keys = ["module-root", "tests-root", "benchmarks-root"] path_list_keys = ["ignore-paths"] diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 5780f4def..ec58a747d 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -554,11 +554,13 @@ def get_all_replay_test_functions( def _get_java_replay_test_functions( - replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path + replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path | str ) -> tuple[dict[Path, list[FunctionToOptimize]], Path]: """Parse Java replay test files to extract functions and trace file path.""" from codeflash.languages.java.replay_test import parse_replay_test_metadata + project_root_path = Path(project_root_path) + trace_file_path: Path | None = None functions: dict[Path, list[FunctionToOptimize]] = defaultdict(list) diff --git a/codeflash/languages/java/build_tools.py b/codeflash/languages/java/build_tools.py index 28db2c9aa..f8a19c693 100644 --- a/codeflash/languages/java/build_tools.py +++ b/codeflash/languages/java/build_tools.py @@ -10,7 +10,8 @@ import xml.etree.ElementTree as ET from dataclasses import dataclass from enum import Enum -from pathlib import Path # noqa: TC003 — used at runtime +from pathlib import Path +from typing import Any logger = logging.getLogger(__name__) @@ -343,6 +344,218 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]: return tests_run, failures, errors, skipped +def parse_java_project_config(project_root: Path) -> dict[str, Any] | None: + """Parse codeflash config from Maven/Gradle build files. + + Reads codeflash.* properties from pom.xml or gradle.properties, + then fills in defaults from auto-detected build tool conventions. + + Returns None if no Java build tool is detected. + """ + build_tool = detect_build_tool(project_root) + if build_tool == BuildTool.UNKNOWN: + return None + + # Read explicit codeflash properties from build files + user_config: dict[str, str] = {} + if build_tool == BuildTool.MAVEN: + user_config = _read_maven_codeflash_properties(project_root) + elif build_tool == BuildTool.GRADLE: + user_config = _read_gradle_codeflash_properties(project_root) + + # Auto-detect defaults — for multi-module Maven projects, scan module pom.xml files + source_root = find_source_root(project_root) + test_root = find_test_root(project_root) + + if build_tool == BuildTool.MAVEN: + source_from_modules, test_from_modules = _detect_roots_from_maven_modules(project_root) + # Module-level pom.xml declarations are more precise than directory-name heuristics + if source_from_modules is not None: + source_root = source_from_modules + if test_from_modules is not None: + test_root = test_from_modules + + # Build the config dict matching the format expected by the rest of codeflash + config: dict[str, Any] = { + "language": "java", + "module_root": str( + (project_root / user_config["moduleRoot"]).resolve() + if "moduleRoot" in user_config + else (source_root or project_root / "src" / "main" / "java") + ), + "tests_root": str( + (project_root / user_config["testsRoot"]).resolve() + if "testsRoot" in user_config + else (test_root or project_root / "src" / "test" / "java") + ), + "pytest_cmd": "pytest", + "git_remote": user_config.get("gitRemote", "origin"), + "disable_telemetry": user_config.get("disableTelemetry", "false").lower() == "true", + "disable_imports_sorting": False, + "override_fixtures": False, + "benchmark": False, + "formatter_cmds": [], + "ignore_paths": [], + } + + if "ignorePaths" in user_config: + config["ignore_paths"] = [ + str((project_root / p.strip()).resolve()) for p in user_config["ignorePaths"].split(",") if p.strip() + ] + + if "formatterCmds" in user_config: + config["formatter_cmds"] = [cmd.strip() for cmd in user_config["formatterCmds"].split(",") if cmd.strip()] + + return config + + +def _read_maven_codeflash_properties(project_root: Path) -> dict[str, str]: + """Read codeflash.* properties from pom.xml section.""" + pom_path = project_root / "pom.xml" + if not pom_path.exists(): + return {} + + try: + tree = _safe_parse_xml(pom_path) + root = tree.getroot() + ns = {"m": "http://maven.apache.org/POM/4.0.0"} + + result: dict[str, str] = {} + for props in [root.find("m:properties", ns), root.find("properties")]: + if props is None: + continue + for child in props: + tag = child.tag + # Strip Maven namespace prefix + if "}" in tag: + tag = tag.split("}", 1)[1] + if tag.startswith("codeflash.") and child.text: + key = tag[len("codeflash.") :] + result[key] = child.text.strip() + return result + except Exception: + logger.debug("Failed to read codeflash properties from pom.xml", exc_info=True) + return {} + + +def _read_gradle_codeflash_properties(project_root: Path) -> dict[str, str]: + """Read codeflash.* properties from gradle.properties.""" + props_path = project_root / "gradle.properties" + if not props_path.exists(): + return {} + + result: dict[str, str] = {} + try: + with props_path.open("r", encoding="utf-8") as f: + for line in f: + stripped = line.strip() + if stripped.startswith("#") or "=" not in stripped: + continue + key, value = stripped.split("=", 1) + key = key.strip() + if key.startswith("codeflash."): + result[key[len("codeflash.") :]] = value.strip() + return result + except Exception: + logger.debug("Failed to read codeflash properties from gradle.properties", exc_info=True) + return {} + + +def _detect_roots_from_maven_modules(project_root: Path) -> tuple[Path | None, Path | None]: + """Scan Maven module pom.xml files for custom sourceDirectory/testSourceDirectory. + + For multi-module projects like aerospike (client/, test/, benchmarks/), + finds the main source module and test module by parsing each module's build config. + """ + pom_path = project_root / "pom.xml" + if not pom_path.exists(): + return None, None + + try: + tree = _safe_parse_xml(pom_path) + root = tree.getroot() + ns = {"m": "http://maven.apache.org/POM/4.0.0"} + + # Find to get module names + modules: list[str] = [] + for modules_elem in [root.find("m:modules", ns), root.find("modules")]: + if modules_elem is not None: + for mod in modules_elem: + if mod.text: + modules.append(mod.text.strip()) + + if not modules: + return None, None + + # Collect candidate source and test roots with Java file counts + source_candidates: list[tuple[Path, int]] = [] + test_root: Path | None = None + + skip_modules = {"example", "examples", "benchmark", "benchmarks", "demo", "sample", "samples"} + + for module_name in modules: + module_pom = project_root / module_name / "pom.xml" + if not module_pom.exists(): + continue + + # Modules named "test" are test modules, not source modules + is_test_module = "test" in module_name.lower() + + try: + mod_tree = _safe_parse_xml(module_pom) + mod_root = mod_tree.getroot() + + for build in [mod_root.find("m:build", ns), mod_root.find("build")]: + if build is None: + continue + + for src_elem in [build.find("m:sourceDirectory", ns), build.find("sourceDirectory")]: + if src_elem is not None and src_elem.text: + src_text = src_elem.text.replace("${project.basedir}", str(project_root / module_name)) + src_path = Path(src_text) + if not src_path.is_absolute(): + src_path = project_root / module_name / src_path + if src_path.exists(): + if is_test_module and test_root is None: + test_root = src_path + elif module_name.lower() not in skip_modules: + java_count = sum(1 for _ in src_path.rglob("*.java")) + if java_count > 0: + source_candidates.append((src_path, java_count)) + + for test_elem in [build.find("m:testSourceDirectory", ns), build.find("testSourceDirectory")]: + if test_elem is not None and test_elem.text: + test_text = test_elem.text.replace("${project.basedir}", str(project_root / module_name)) + test_path = Path(test_text) + if not test_path.is_absolute(): + test_path = project_root / module_name / test_path + if test_path.exists() and test_root is None: + test_root = test_path + + # Also check standard module layouts + if module_name.lower() not in skip_modules and not is_test_module: + std_src = project_root / module_name / "src" / "main" / "java" + if std_src.exists(): + java_count = sum(1 for _ in std_src.rglob("*.java")) + if java_count > 0: + source_candidates.append((std_src, java_count)) + + if test_root is None: + std_test = project_root / module_name / "src" / "test" / "java" + if std_test.exists() and any(std_test.rglob("*.java")): + test_root = std_test + + except Exception: + continue + + # Pick the source root with the most Java files (likely the main library) + source_root = max(source_candidates, key=lambda x: x[1])[0] if source_candidates else None + return source_root, test_root + + except Exception: + return None, None + + def find_test_root(project_root: Path) -> Path | None: """Find the test root directory for a Java project. diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 9ecbd613e..914fe7a70 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -785,26 +785,35 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, if _is_test_annotation(stripped): if not helper_added: helper_added = True - result.append(line) - i += 1 - # Collect any additional annotations - while i < len(lines) and lines[i].strip().startswith("@"): - result.append(lines[i]) + # Check if the @Test line already contains the method signature and opening brace + # (common in compact test styles like replay tests: @Test void replay_foo_0() throws Exception {) + if "{" in line: + # The annotation line IS the method signature — don't look for a separate one + result.append(line) i += 1 - - # Now find the method signature and opening brace - method_lines = [] - while i < len(lines): - method_lines.append(lines[i]) - if "{" in lines[i]: - break + method_lines = [line] + else: + result.append(line) i += 1 - # Add the method signature lines - for ml in method_lines: - result.append(ml) - i += 1 + # Collect any additional annotations + while i < len(lines) and lines[i].strip().startswith("@"): + result.append(lines[i]) + i += 1 + + # Now find the method signature and opening brace + method_lines = [] + while i < len(lines): + method_lines.append(lines[i]) + if "{" in lines[i]: + break + i += 1 + + # Add the method signature lines + for ml in method_lines: + result.append(ml) + i += 1 # Extract the test method name from the method signature test_method_name = _extract_test_method_name(method_lines) diff --git a/codeflash/languages/java/jfr_parser.py b/codeflash/languages/java/jfr_parser.py index 7775378e6..7f3816856 100644 --- a/codeflash/languages/java/jfr_parser.py +++ b/codeflash/languages/java/jfr_parser.py @@ -152,6 +152,8 @@ def _frame_to_key(self, frame: dict[str, Any]) -> str | None: method_name = method.get("name", "") if not class_name or not method_name: return None + # JFR uses / separators (JVM internal format), normalize to dots for package matching + class_name = class_name.replace("/", ".") return f"{class_name}.{method_name}" def _store_method_info(self, key: str, frame: dict[str, Any]) -> None: @@ -159,7 +161,7 @@ def _store_method_info(self, key: str, frame: dict[str, Any]) -> None: return method = frame.get("method", {}) self._method_info[key] = { - "class_name": method.get("type", {}).get("name", ""), + "class_name": method.get("type", {}).get("name", "").replace("/", "."), "method_name": method.get("name", ""), "descriptor": method.get("descriptor", ""), "line_number": str(frame.get("lineNumber", 0)), diff --git a/codeflash/languages/java/replay_test.py b/codeflash/languages/java/replay_test.py index c753bf4fa..415b7a34e 100644 --- a/codeflash/languages/java/replay_test.py +++ b/codeflash/languages/java/replay_test.py @@ -12,9 +12,12 @@ logger = logging.getLogger(__name__) -def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256) -> int: - """Generate JUnit 5 replay test files from a trace SQLite database. +def generate_replay_tests( + trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256, test_framework: str = "junit5" +) -> int: + """Generate JUnit replay test files from a trace SQLite database. + Supports both JUnit 5 (default) and JUnit 4. Returns the number of test files generated. """ if not trace_db_path.exists(): @@ -44,9 +47,10 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P test_methods_code: list[str] = [] class_function_names: list[str] = [] + # Global test counter to avoid duplicate method names for overloaded Java methods + method_name_counters: dict[str, int] = {} for method_name, descriptor in method_list: - # Count invocations for this method count_result = conn.execute( "SELECT COUNT(*) FROM function_calls WHERE classname = ? AND function = ? AND descriptor = ?", (classname, method_name, descriptor), @@ -57,9 +61,14 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P safe_method = _sanitize_identifier(method_name) for i in range(invocation_count): + # Use a global counter per method name to avoid collisions on overloaded methods + test_idx = method_name_counters.get(safe_method, 0) + method_name_counters[safe_method] = test_idx + 1 + escaped_descriptor = descriptor.replace('"', '\\"') + access = "public " if test_framework == "junit4" else "" test_methods_code.append( - f" @Test void replay_{safe_method}_{i}() throws Exception {{\n" + f" @Test {access}void replay_{safe_method}_{test_idx}() throws Exception {{\n" f' helper.replay("{classname}", "{method_name}", ' f'"{escaped_descriptor}", {i});\n' f" }}" @@ -69,18 +78,28 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P # Generate the test file functions_comment = ",".join(class_function_names) + if test_framework == "junit4": + test_imports = "import org.junit.Test;\nimport org.junit.AfterClass;\n" + cleanup_annotation = "@AfterClass" + class_modifier = "public " + else: + test_imports = "import org.junit.jupiter.api.Test;\nimport org.junit.jupiter.api.AfterAll;\n" + cleanup_annotation = "@AfterAll" + class_modifier = "" + test_content = ( f"// codeflash:functions={functions_comment}\n" f"// codeflash:trace_file={trace_db_path.as_posix()}\n" f"// codeflash:classname={classname}\n" f"package codeflash.replay;\n\n" - f"import org.junit.jupiter.api.Test;\n" - f"import org.junit.jupiter.api.AfterAll;\n" + f"{test_imports}" f"import com.codeflash.ReplayHelper;\n\n" - f"class {test_class_name} {{\n" + f"{class_modifier}class {test_class_name} {{\n" f" private static final ReplayHelper helper =\n" f' new ReplayHelper("{trace_db_path.as_posix()}");\n\n' - f" @AfterAll static void cleanup() {{ helper.close(); }}\n\n" + "\n\n".join(test_methods_code) + "\n" + f" {cleanup_annotation} public static void cleanup() {{ helper.close(); }}\n\n" + + "\n\n".join(test_methods_code) + + "\n" "}\n" ) diff --git a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar index cfcee9390..48ebc0a96 100644 Binary files a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar and b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar differ diff --git a/codeflash/languages/java/tracer.py b/codeflash/languages/java/tracer.py index 7b5a30421..5cc098be5 100644 --- a/codeflash/languages/java/tracer.py +++ b/codeflash/languages/java/tracer.py @@ -14,6 +14,39 @@ logger = logging.getLogger(__name__) +GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL + + +def _run_java_with_graceful_timeout( + java_command: list[str], env: dict[str, str], timeout: int, stage_name: str +) -> None: + """Run a Java command with graceful timeout handling. + + Sends SIGTERM first (allowing JFR dump and shutdown hooks to run), + then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds. + """ + if not timeout: + subprocess.run(java_command, env=env, check=False) + return + + import signal + + proc = subprocess.Popen(java_command, env=env) + try: + proc.wait(timeout=timeout) + except subprocess.TimeoutExpired: + logger.warning( + "%s stage timed out after %d seconds, sending SIGTERM for graceful shutdown...", stage_name, timeout + ) + proc.send_signal(signal.SIGTERM) + try: + proc.wait(timeout=GRACEFUL_SHUTDOWN_WAIT) + except subprocess.TimeoutExpired: + logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name) + proc.kill() + proc.wait() + + # --add-opens flags needed for Kryo serialization on Java 16+ ADD_OPENS_FLAGS = ( "--add-opens=java.base/java.util=ALL-UNNAMED " @@ -48,10 +81,7 @@ def trace( # Stage 1: JFR Profiling logger.info("Stage 1: Running JFR profiling...") jfr_env = self.build_jfr_env(jfr_file) - try: - subprocess.run(java_command, env=jfr_env, check=False, timeout=timeout or None) - except subprocess.TimeoutExpired: - logger.warning("JFR profiling stage timed out after %d seconds", timeout) + _run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling") if not jfr_file.exists(): logger.warning("JFR file was not created at %s", jfr_file) @@ -62,10 +92,7 @@ def trace( trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout ) agent_env = self.build_agent_env(config_path) - try: - subprocess.run(java_command, env=agent_env, check=False, timeout=timeout or None) - except subprocess.TimeoutExpired: - logger.warning("Argument capture stage timed out after %d seconds", timeout) + _run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture") if not trace_db_path.exists(): logger.error("Trace database was not created at %s", trace_db_path) @@ -95,7 +122,12 @@ def create_tracer_config( def build_jfr_env(self, jfr_file: Path) -> dict[str, str]: env = os.environ.copy() - jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" + # Use profile settings with increased sampling frequency (1ms instead of default 10ms) + # This captures more samples for short-running programs + jfr_opts = ( + f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" + ",jdk.ExecutionSample#period=1ms" + ) existing = env.get("JAVA_TOOL_OPTIONS", "") env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts}".strip() return env @@ -153,6 +185,7 @@ def run_java_tracer( max_function_count: int = 256, timeout: int = 0, max_run_count: int = 256, + test_framework: str = "junit5", ) -> tuple[Path, Path, int]: """High-level entry point: trace a Java command and generate replay tests. @@ -169,7 +202,11 @@ def run_java_tracer( ) test_count = generate_replay_tests( - trace_db_path=trace_db, output_dir=output_dir, project_root=project_root, max_run_count=max_run_count + trace_db_path=trace_db, + output_dir=output_dir, + project_root=project_root, + max_run_count=max_run_count, + test_framework=test_framework, ) return trace_db, jfr_file, test_count diff --git a/codeflash/setup/config_writer.py b/codeflash/setup/config_writer.py index 0889690d5..4616ccf5f 100644 --- a/codeflash/setup/config_writer.py +++ b/codeflash/setup/config_writer.py @@ -8,7 +8,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import tomlkit @@ -38,7 +38,7 @@ def write_config(detected: DetectedProject, config: CodeflashConfig | None = Non if detected.language == "python": return _write_pyproject_toml(detected.project_root, config) if detected.language == "java": - return _write_codeflash_toml(detected.project_root, config) + return _write_java_build_config(detected.project_root, config) return _write_package_json(detected.project_root, config) @@ -92,10 +92,10 @@ def _write_pyproject_toml(project_root: Path, config: CodeflashConfig) -> tuple[ return False, f"Failed to write pyproject.toml: {e}" -def _write_codeflash_toml(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]: - """Write config to codeflash.toml [tool.codeflash] section for Java projects. +def _write_java_build_config(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]: + """Write codeflash config to pom.xml properties or gradle.properties. - Creates codeflash.toml if it doesn't exist. + Only writes non-default values. Standard Maven/Gradle layouts need no config. Args: project_root: Project root directory. @@ -105,40 +105,110 @@ def _write_codeflash_toml(project_root: Path, config: CodeflashConfig) -> tuple[ Tuple of (success, message). """ - codeflash_toml_path = project_root / "codeflash.toml" + config_dict = config.to_pyproject_dict() - try: - # Load existing or create new - if codeflash_toml_path.exists(): - with codeflash_toml_path.open("rb") as f: - doc = tomlkit.parse(f.read()) - else: - doc = tomlkit.document() + # Filter out default values — only write overrides + defaults = {"module-root": "src/main/java", "tests-root": "src/test/java", "language": "java"} + non_default = {k: v for k, v in config_dict.items() if k not in defaults or str(v) != defaults.get(k)} + # Remove empty lists and False booleans + non_default = {k: v for k, v in non_default.items() if v not in ([], False, "", None)} - # Ensure [tool] section exists - if "tool" not in doc: - doc["tool"] = tomlkit.table() + if not non_default: + return True, "Standard Maven/Gradle layout detected — no config needed" - # Create codeflash section - codeflash_table = tomlkit.table() - codeflash_table.add(tomlkit.comment("Codeflash configuration for Java - https://docs.codeflash.ai")) + pom_path = project_root / "pom.xml" + if pom_path.exists(): + return _write_maven_properties(pom_path, non_default) - # Add config values - config_dict = config.to_pyproject_dict() - for key, value in config_dict.items(): - codeflash_table[key] = value + gradle_props_path = project_root / "gradle.properties" + return _write_gradle_properties(gradle_props_path, non_default) - # Update the document - doc["tool"]["codeflash"] = codeflash_table - # Write back - with codeflash_toml_path.open("w", encoding="utf8") as f: - f.write(tomlkit.dumps(doc)) +def _write_maven_properties(pom_path: Path, config: dict[str, Any]) -> tuple[bool, str]: + """Add codeflash.* properties to pom.xml section.""" + import xml.etree.ElementTree as ET - return True, f"Config saved to {codeflash_toml_path}" + try: + tree = ET.parse(str(pom_path)) + root = tree.getroot() + ns = {"m": "http://maven.apache.org/POM/4.0.0"} + + # Find or create + properties = root.find("m:properties", ns) or root.find("properties") + if properties is None: + properties = ET.SubElement(root, "properties") + + # Convert kebab-case keys to camelCase for Maven convention + key_map = { + "module-root": "moduleRoot", + "tests-root": "testsRoot", + "git-remote": "gitRemote", + "disable-telemetry": "disableTelemetry", + "ignore-paths": "ignorePaths", + "formatter-cmds": "formatterCmds", + } + + for key, value in config.items(): + maven_key = f"codeflash.{key_map.get(key, key)}" + if isinstance(value, list): + value = ",".join(str(v) for v in value) + elif isinstance(value, bool): + value = str(value).lower() + else: + value = str(value) + + existing = properties.find(maven_key) + if existing is None: + elem = ET.SubElement(properties, maven_key) + elem.text = value + else: + existing.text = value + + tree.write(str(pom_path), xml_declaration=True, encoding="UTF-8") + return True, f"Config saved to {pom_path} " except Exception as e: - return False, f"Failed to write codeflash.toml: {e}" + return False, f"Failed to write Maven properties: {e}" + + +def _write_gradle_properties(props_path: Path, config: dict[str, Any]) -> tuple[bool, str]: + """Add codeflash.* entries to gradle.properties.""" + key_map = { + "module-root": "moduleRoot", + "tests-root": "testsRoot", + "git-remote": "gitRemote", + "disable-telemetry": "disableTelemetry", + "ignore-paths": "ignorePaths", + "formatter-cmds": "formatterCmds", + } + + try: + lines = [] + if props_path.exists(): + lines = props_path.read_text(encoding="utf-8").splitlines() + + # Remove existing codeflash.* lines + lines = [line for line in lines if not line.strip().startswith("codeflash.")] + + # Add new config + if lines and lines[-1].strip(): + lines.append("") + lines.append("# Codeflash configuration — https://docs.codeflash.ai") + for key, value in config.items(): + gradle_key = f"codeflash.{key_map.get(key, key)}" + if isinstance(value, list): + value = ",".join(str(v) for v in value) + elif isinstance(value, bool): + value = str(value).lower() + else: + value = str(value) + lines.append(f"{gradle_key}={value}") + + props_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return True, f"Config saved to {props_path}" + + except Exception as e: + return False, f"Failed to write gradle.properties: {e}" def _write_package_json(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]: @@ -206,7 +276,7 @@ def remove_config(project_root: Path, language: str) -> tuple[bool, str]: if language == "python": return _remove_from_pyproject(project_root) if language == "java": - return _remove_from_codeflash_toml(project_root) + return _remove_java_build_config(project_root) return _remove_from_package_json(project_root) @@ -235,29 +305,45 @@ def _remove_from_pyproject(project_root: Path) -> tuple[bool, str]: return False, f"Failed to remove config: {e}" -def _remove_from_codeflash_toml(project_root: Path) -> tuple[bool, str]: - """Remove [tool.codeflash] section from codeflash.toml.""" - codeflash_toml_path = project_root / "codeflash.toml" - - if not codeflash_toml_path.exists(): - return True, "No codeflash.toml found" - - try: - with codeflash_toml_path.open("rb") as f: - doc = tomlkit.parse(f.read()) - - if "tool" in doc and "codeflash" in doc["tool"]: - del doc["tool"]["codeflash"] - - with codeflash_toml_path.open("w", encoding="utf8") as f: - f.write(tomlkit.dumps(doc)) - - return True, "Removed [tool.codeflash] section from codeflash.toml" - - return True, "No codeflash config found in codeflash.toml" - - except Exception as e: - return False, f"Failed to remove config: {e}" +def _remove_java_build_config(project_root: Path) -> tuple[bool, str]: + """Remove codeflash.* properties from pom.xml or gradle.properties.""" + # Try gradle.properties first (simpler) + gradle_props = project_root / "gradle.properties" + if gradle_props.exists(): + try: + lines = gradle_props.read_text(encoding="utf-8").splitlines() + filtered = [ + line + for line in lines + if not line.strip().startswith("codeflash.") + and line.strip() != "# Codeflash configuration — https://docs.codeflash.ai" + ] + gradle_props.write_text("\n".join(filtered) + "\n", encoding="utf-8") + return True, "Removed codeflash properties from gradle.properties" + except Exception as e: + return False, f"Failed to remove config from gradle.properties: {e}" + + # Try pom.xml + pom_path = project_root / "pom.xml" + if pom_path.exists(): + try: + import xml.etree.ElementTree as ET + + tree = ET.parse(str(pom_path)) + root = tree.getroot() + ns = {"m": "http://maven.apache.org/POM/4.0.0"} + for properties in [root.find("m:properties", ns), root.find("properties")]: + if properties is None: + continue + to_remove = [child for child in properties if child.tag.split("}")[-1].startswith("codeflash.")] + for elem in to_remove: + properties.remove(elem) + tree.write(str(pom_path), xml_declaration=True, encoding="UTF-8") + return True, "Removed codeflash properties from pom.xml" + except Exception as e: + return False, f"Failed to remove config from pom.xml: {e}" + + return True, "No Java build config found" def _remove_from_package_json(project_root: Path) -> tuple[bool, str]: diff --git a/codeflash/setup/detector.py b/codeflash/setup/detector.py index defe1a22d..06d690190 100644 --- a/codeflash/setup/detector.py +++ b/codeflash/setup/detector.py @@ -886,20 +886,24 @@ def has_existing_config(project_root: Path) -> tuple[bool, str | None]: Returns: Tuple of (has_config, config_file_type). - config_file_type is "pyproject.toml", "codeflash.toml", "package.json", or None. + config_file_type is "pyproject.toml", "pom.xml", "build.gradle", "package.json", or None. """ - # Check TOML config files (pyproject.toml, codeflash.toml) - for toml_filename in ("pyproject.toml", "codeflash.toml"): - toml_path = project_root / toml_filename - if toml_path.exists(): - try: - with toml_path.open("rb") as f: - data = tomlkit.parse(f.read()) - if "tool" in data and "codeflash" in data["tool"]: - return True, toml_filename - except Exception: - pass + # Check pyproject.toml (Python projects) + pyproject_path = project_root / "pyproject.toml" + if pyproject_path.exists(): + try: + with pyproject_path.open("rb") as f: + data = tomlkit.parse(f.read()) + if "tool" in data and "codeflash" in data["tool"]: + return True, "pyproject.toml" + except Exception: + pass + + # Check Java build files — Java projects store config in pom.xml properties or gradle.properties + for build_file in ("pom.xml", "build.gradle", "build.gradle.kts"): + if (project_root / build_file).exists(): + return True, build_file # Check package.json package_json_path = project_root / "package.json" diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 84f58e9da..5f8a1a4ab 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -38,7 +38,7 @@ def _detect_non_python_language(args: Namespace | None) -> Language | None: - """Detect if the project uses a non-Python language from --file or config. + """Detect if the project uses a non-Python language from --file or build files. Returns a Language enum value if non-Python detected, None otherwise. """ @@ -66,15 +66,23 @@ def _detect_non_python_language(args: Namespace | None) -> Language | None: except Exception: pass - # Method 2: Check project config for language field + # Method 2: Detect Java from build files (pom.xml / build.gradle) + try: + from codeflash.languages.java.build_tools import BuildTool, detect_build_tool + + cwd = Path.cwd() + if detect_build_tool(cwd) != BuildTool.UNKNOWN: + return Language.JAVA + except Exception: + pass + + # Method 3: Check config file for language field (JS/TS via package.json) try: from codeflash.code_utils.config_parser import parse_config_file config_file = getattr(args, "config_file_path", None) if args else None config, _ = parse_config_file(config_file) lang_str = config.get("language", "") - if lang_str == "java": - return Language.JAVA if lang_str in ("javascript", "typescript"): return Language(lang_str) except Exception: @@ -336,8 +344,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: max_function_count = getattr(config, "max_function_count", 256) timeout = int(getattr(config, "timeout", None) or getattr(config, "tracer_timeout", 0) or 0) + console.print("[bold]Java project detected[/]") + console.print(f" Project root: {project_root}") + console.print(f" Module root: {getattr(config, 'module_root', '?')}") + console.print(f" Tests root: {getattr(config, 'tests_root', '?')}") + from codeflash.code_utils.code_utils import get_run_tmp_file - from codeflash.languages.java.build_tools import find_test_root from codeflash.languages.java.tracer import JavaTracer, run_java_tracer tracer = JavaTracer() @@ -347,12 +359,16 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: trace_db_path = get_run_tmp_file(Path("java_trace.db")) - # Place replay tests in the project's test source tree so Maven/Gradle can compile them - test_root = find_test_root(project_root) - if test_root: - output_dir = test_root / "codeflash" / "replay" + # Place replay tests in the project's test source tree so Maven/Gradle can compile them. + # Use the config's tests_root (correctly resolved for multi-module projects) not find_test_root(). + tests_root = Path(getattr(config, "tests_root", "")) + if tests_root.is_dir(): + output_dir = tests_root / "codeflash" / "replay" else: - output_dir = project_root / "src" / "test" / "java" / "codeflash" / "replay" + from codeflash.languages.java.build_tools import find_test_root + + test_root = find_test_root(project_root) + output_dir = (test_root or project_root / "src" / "test" / "java") / "codeflash" / "replay" output_dir.mkdir(parents=True, exist_ok=True) # Remaining args after our flags are the Java command @@ -364,6 +380,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: sys.exit(1) java_command = remaining + # Detect test framework for replay test generation + from codeflash.languages.java.config import detect_java_project + + java_config = detect_java_project(project_root) + test_framework = java_config.test_framework if java_config else "junit5" + trace_db, jfr_file, test_count = run_java_tracer( java_command=java_command, trace_db_path=trace_db_path, @@ -372,6 +394,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: output_dir=output_dir, max_function_count=max_function_count, timeout=timeout, + test_framework=test_framework, ) console.print(f"[bold green]Java tracing complete:[/] {test_count} replay test files generated") diff --git a/docs/configuration/java.mdx b/docs/configuration/java.mdx index 9d110fc55..720e5e091 100644 --- a/docs/configuration/java.mdx +++ b/docs/configuration/java.mdx @@ -1,101 +1,112 @@ --- title: "Java Configuration" -description: "Configure Codeflash for Java projects using codeflash.toml" +description: "Configure Codeflash for Java projects — zero config for standard layouts" icon: "java" -sidebarTitle: "Java (codeflash.toml)" +sidebarTitle: "Java (pom.xml / Gradle)" keywords: [ "configuration", - "codeflash.toml", "java", "maven", "gradle", "junit", + "pom.xml", + "gradle.properties", + "zero-config", ] --- # Java Configuration -Codeflash stores its configuration in `codeflash.toml` under the `[tool.codeflash]` section. +**Standard Maven/Gradle projects need zero configuration.** Codeflash auto-detects your project structure from `pom.xml` or `build.gradle` — no config file is required. -## Full Reference - -```toml -[tool.codeflash] -# Required -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" - -# Optional -test-framework = "junit5" # "junit5", "junit4", or "testng" -disable-telemetry = false -git-remote = "origin" -ignore-paths = ["src/main/java/generated/"] -``` - -All file paths are relative to the directory containing `codeflash.toml`. - - -Codeflash auto-detects most settings from your project structure. Running `codeflash init` will set up the correct config — manual configuration is usually not needed. - +For projects with non-standard layouts, you can add `codeflash.*` properties to your existing `pom.xml` or `gradle.properties`. ## Auto-Detection -When you run `codeflash init`, Codeflash inspects your project and auto-detects: +Codeflash inspects your build files and auto-detects: | Setting | Detection logic | |---------|----------------| -| `module-root` | Looks for `src/main/java` (Maven/Gradle standard layout) | -| `tests-root` | Looks for `src/test/java`, `test/`, `tests/` | -| `language` | Detected from build files (`pom.xml`, `build.gradle`) and `.java` files | -| `test-framework` | Checks build file dependencies for JUnit 5, JUnit 4, or TestNG | - -## Required Options - -- **`module-root`**: The source directory to optimize. Only code under this directory is discovered for optimization. For standard Maven/Gradle projects, this is `src/main/java`. -- **`tests-root`**: The directory where your tests are located. Codeflash discovers existing tests and places generated replay tests here. -- **`language`**: Must be set to `"java"` for Java projects. +| **Language** | Presence of `pom.xml` or `build.gradle` / `build.gradle.kts` | +| **Source root** | `src/main/java` (standard), or `` in `pom.xml`, or Gradle `sourceSets` | +| **Test root** | `src/test/java` (standard), or `` in `pom.xml` | +| **Test framework** | Checks build file dependencies for JUnit 5, JUnit 4, or TestNG | +| **Java version** | ``, `` in `pom.xml` | -## Optional Options +### Multi-module Maven projects -- **`test-framework`**: Test framework. Auto-detected from build dependencies. Supported values: `"junit5"` (default), `"junit4"`, `"testng"`. -- **`disable-telemetry`**: Disable anonymized telemetry. Defaults to `false`. -- **`git-remote`**: Git remote for pull requests. Defaults to `"origin"`. -- **`ignore-paths`**: Paths within `module-root` to skip during optimization. +For multi-module projects, Codeflash scans each module's `pom.xml` for `` and `` declarations. It picks the module with the most Java source files as the main source root, and identifies test modules by name. -## Multi-Module Projects - -For multi-module Maven/Gradle projects, place `codeflash.toml` at the project root and set `module-root` to the module you want to optimize: +For example, with this layout: ```text my-project/ -|- client/ -| |- src/main/java/com/example/client/ -| |- src/test/java/com/example/client/ -|- server/ -| |- src/main/java/com/example/server/ -|- pom.xml -|- codeflash.toml +|- client/ ← main library (most .java files) +| |- src/com/example/ +| |- pom.xml ← ${project.basedir}/src +|- test/ ← test module +| |- src/com/example/ +| |- pom.xml ← ${project.basedir}/src +|- benchmarks/ ← skipped (benchmark module) +|- pom.xml ← client, test, benchmarks ``` -```toml -[tool.codeflash] -module-root = "client/src/main/java" -tests-root = "client/src/test/java" -language = "java" +Codeflash auto-detects `client/src` as the source root and `test/src` as the test root — no manual configuration needed. + +## Custom Configuration + +If auto-detection doesn't match your project layout, add `codeflash.*` properties to your build files. + + + + +Add properties to your `pom.xml` `` section: + +```xml + + + client/src + test/src + true + upstream + src/main/java/generated/,src/main/java/proto/ + ``` -For non-standard layouts (like the Aerospike client where source is under `client/src/`), adjust paths accordingly: +This follows the same pattern as SonarQube (`sonar.sources`), JaCoCo, and other Java tools — config lives in the build file, not a separate tool-specific file. + + + + +Add properties to `gradle.properties`: -```toml -[tool.codeflash] -module-root = "client/src" -tests-root = "test/src" -language = "java" +```properties +# Only set values that differ from auto-detected defaults +codeflash.moduleRoot=lib/src/main/java +codeflash.testsRoot=lib/src/test/java +codeflash.disableTelemetry=true +codeflash.gitRemote=upstream +codeflash.ignorePaths=src/main/java/generated/ ``` -## Tracer Options + + + +## Available Properties + +All properties are optional — only set values that differ from auto-detected defaults. + +| Property | Description | Default | +|----------|------------|---------| +| `codeflash.moduleRoot` | Source directory to optimize | Auto-detected from `` or `src/main/java` | +| `codeflash.testsRoot` | Test directory | Auto-detected from `` or `src/test/java` | +| `codeflash.disableTelemetry` | Disable anonymized telemetry | `false` | +| `codeflash.gitRemote` | Git remote for pull requests | `origin` | +| `codeflash.ignorePaths` | Comma-separated paths to skip during optimization | Empty | +| `codeflash.formatterCmds` | Comma-separated formatter commands (`$file` = file path) | Empty | + +## Tracer CLI Options When using `codeflash optimize` to trace a Java program, these CLI options are available: @@ -111,9 +122,9 @@ Example with timeout: codeflash optimize --timeout 30 java -jar target/my-app.jar --app-args ``` -## Example +## Examples -### Standard Maven project +### Standard Maven project (zero config) ```text my-app/ @@ -124,17 +135,14 @@ my-app/ | |- test/java/com/example/ | |- AppTest.java |- pom.xml -|- codeflash.toml ``` -```toml -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" +Just run: +```bash +codeflash optimize java -jar target/my-app.jar ``` -### Gradle project +### Standard Gradle project (zero config) ```text my-lib/ @@ -142,12 +150,55 @@ my-lib/ | |- main/java/com/example/ | |- test/java/com/example/ |- build.gradle -|- codeflash.toml ``` -```toml -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" +Just run: +```bash +codeflash optimize java -cp build/classes/java/main com.example.Main ``` + +### Non-standard layout (with config) + +```text +aerospike-client-java/ +|- client/ +| |- src/com/aerospike/client/ ← source here (not src/main/java) +| |- pom.xml +|- test/ +| |- src/com/aerospike/test/ ← tests here +| |- pom.xml +|- pom.xml +``` + +If auto-detection doesn't pick up the right modules, add to the root `pom.xml`: + +```xml + + client/src + test/src + +``` + + +In most cases, even non-standard multi-module layouts are auto-detected correctly from `` and `` in each module's `pom.xml`. Only add manual config if auto-detection gets it wrong. + + +## FAQ + + + + No. Codeflash auto-detects Java projects from `pom.xml` or `build.gradle`. No initialization step or config file is needed for standard layouts. + + + + Codeflash reads config from your existing build files — `pom.xml` `` for Maven, `gradle.properties` for Gradle. No separate config file is created. + + + + Add `` and `` properties to your `pom.xml` or `gradle.properties`. These override auto-detection. + + + + Codeflash scans each module's `pom.xml` for `` and ``. It picks the module with the most Java files as the source root (skipping modules named `examples`, `benchmarks`, etc.) and identifies `test` modules for the test root. + + diff --git a/docs/getting-started/java-installation.mdx b/docs/getting-started/java-installation.mdx index a75e1f0b7..fb2a88ef2 100644 --- a/docs/getting-started/java-installation.mdx +++ b/docs/getting-started/java-installation.mdx @@ -12,10 +12,11 @@ keywords: "junit", "junit5", "tracing", + "zero-config", ] --- -Codeflash supports Java projects using Maven or Gradle build systems. It uses a two-stage tracing approach to capture method arguments and profiling data from running Java programs, then optimizes the hottest functions. +Codeflash supports Java projects using Maven or Gradle build systems. **No configuration file is needed** — Codeflash auto-detects your project structure from `pom.xml` or `build.gradle`. ### Prerequisites @@ -23,7 +24,7 @@ Before installing Codeflash, ensure you have: 1. **Java 11 or above** installed 2. **Maven or Gradle** as your build tool -3. **A Java project** with source code under a standard directory layout +3. **A Java project** with source code Good to have (optional): @@ -45,61 +46,48 @@ uv pip install codeflash ``` - + Navigate to your Java project root (where `pom.xml` or `build.gradle` is) and run: ```bash -codeflash init +codeflash optimize java -jar target/my-app.jar ``` -This will: -- Detect your build tool (Maven/Gradle) -- Find your source and test directories -- Create a `codeflash.toml` configuration file +That's it — no `init` step, no config file. Codeflash detects Maven/Gradle automatically and infers source and test directories from your build files. - - +Codeflash will: +1. Profile your program using JFR (Java Flight Recorder) +2. Capture method arguments using a bytecode instrumentation agent +3. Generate JUnit replay tests from the captured data +4. Rank functions by performance impact +5. Optimize the most impactful functions -Check that the configuration looks correct: + + -```bash -cat codeflash.toml -``` + +**Zero config for standard projects.** If your project uses the standard Maven/Gradle layout (`src/main/java`, `src/test/java`), everything is auto-detected. For non-standard layouts, see the [configuration guide](/configuration/java). + -You should see something like: +## Usage examples -```toml -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" +**Trace and optimize a JAR application:** +```bash +codeflash optimize java -jar target/my-app.jar --app-args ``` - - - -Trace and optimize a running Java program: - +**Optimize a specific file and function:** ```bash -codeflash optimize java -jar target/my-app.jar +codeflash --file src/main/java/com/example/Utils.java --function computeHash ``` -Or with Maven: - +**Trace a long-running program with a timeout:** ```bash -codeflash optimize mvn exec:java -Dexec.mainClass="com.example.Main" +codeflash optimize --timeout 30 java -jar target/my-server.jar ``` -Codeflash will: -1. Profile your program using JFR (Java Flight Recorder) -2. Capture method arguments using a bytecode instrumentation agent -3. Generate JUnit replay tests from the captured data -4. Rank functions by performance impact -5. Optimize the most impactful functions - - - +Each tracing stage runs for at most 30 seconds, then the captured data is processed. ## How it works diff --git a/tests/scripts/end_to_end_test_java_tracer.py b/tests/scripts/end_to_end_test_java_tracer.py index e904a4e98..5555b041c 100644 --- a/tests/scripts/end_to_end_test_java_tracer.py +++ b/tests/scripts/end_to_end_test_java_tracer.py @@ -59,6 +59,7 @@ def run_test(expected_improvement_pct: int) -> bool: env = os.environ.copy() env["PYTHONIOENCODING"] = "utf-8" + env["PYTHONUNBUFFERED"] = "1" logging.info(f"Running command: {' '.join(command)}") logging.info(f"Working directory: {fixture_dir}") process = subprocess.Popen( @@ -73,13 +74,11 @@ def run_test(expected_improvement_pct: int) -> bool: output = [] for line in process.stdout: - logging.info(line.strip()) + print(line, end="", flush=True) output.append(line) return_code = process.wait() stdout = "".join(output) - if return_code != 0: - logging.error(f"Full output:\n{stdout}") if return_code != 0: logging.error(f"Command returned exit code {return_code}") @@ -90,7 +89,7 @@ def run_test(expected_improvement_pct: int) -> bool: logging.error("Failed to find replay test generation message") return False - # Validate: replay tests were discovered + # Validate: replay tests were discovered (global count) replay_match = re.search(r"Discovered \d+ existing unit tests? and (\d+) replay tests?", stdout) if not replay_match: logging.error("Failed to find replay test discovery message") @@ -101,6 +100,17 @@ def run_test(expected_improvement_pct: int) -> bool: return False logging.info(f"Replay tests discovered: {num_replay}") + # Validate: replay test files were used per-function + replay_file_match = re.search(r"Discovered \d+ existing unit test files?, (\d+) replay test files?", stdout) + if not replay_file_match: + logging.error("Failed to find per-function replay test file discovery message") + return False + num_replay_files = int(replay_file_match.group(1)) + if num_replay_files == 0: + logging.error("No replay test files discovered per-function") + return False + logging.info(f"Replay test files per-function: {num_replay_files}") + # Validate: at least one optimization was found if "⚡️ Optimization successful! 📄 " not in stdout: logging.error("Failed to find optimization success message") diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index 12259b339..33825db4d 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -149,8 +149,8 @@ def build_command( if config.function_name: base_command.extend(["--function", config.function_name]) - # Check if config exists (pyproject.toml or codeflash.toml) - if so, don't override it - has_codeflash_config = (cwd / "codeflash.toml").exists() + # Check if config exists (pyproject.toml, pom.xml, build.gradle) - if so, don't override it + has_codeflash_config = (cwd / "pom.xml").exists() or (cwd / "build.gradle").exists() or (cwd / "build.gradle.kts").exists() if not has_codeflash_config: pyproject_path = cwd / "pyproject.toml" if pyproject_path.exists(): diff --git a/tests/test_languages/fixtures/java_maven/codeflash.toml b/tests/test_languages/fixtures/java_maven/codeflash.toml deleted file mode 100644 index ecd20a562..000000000 --- a/tests/test_languages/fixtures/java_maven/codeflash.toml +++ /dev/null @@ -1,5 +0,0 @@ -# Codeflash configuration for Java project - -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" diff --git a/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml b/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml deleted file mode 100644 index a501ef8cb..000000000 --- a/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml +++ /dev/null @@ -1,6 +0,0 @@ -# Codeflash configuration for Java project - -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" diff --git a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java index 9b6078000..7beb2a4ea 100644 --- a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java +++ b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java @@ -36,20 +36,30 @@ public int instanceMethod(int x, int y) { } public static void main(String[] args) { - // Exercise the methods so the tracer can capture invocations - System.out.println("computeSum(100) = " + computeSum(100)); - System.out.println("computeSum(50) = " + computeSum(50)); + // Run methods with large inputs so JFR can capture CPU samples. + // Small inputs finish too fast (<1ms) for JFR's 10ms sampling interval. + for (int round = 0; round < 1000; round++) { + computeSum(100_000); + repeatString("hello world ", 1000); + + List nums = new ArrayList<>(); + for (int i = 1; i <= 10_000; i++) nums.add(i); + filterEvens(nums); + Workload w = new Workload(); + w.instanceMethod(100_000, 42); + } + + // Also call with small inputs for variety in traced args + System.out.println("computeSum(100) = " + computeSum(100)); System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3)); - System.out.println("repeatString(\"x\", 5) = " + repeatString("x", 5)); - List nums = new ArrayList<>(); - for (int i = 1; i <= 10; i++) nums.add(i); - System.out.println("filterEvens(1..10) = " + filterEvens(nums)); + List small = new ArrayList<>(); + for (int i = 1; i <= 10; i++) small.add(i); + System.out.println("filterEvens(1..10) = " + filterEvens(small)); Workload w = new Workload(); System.out.println("instanceMethod(5, 3) = " + w.instanceMethod(5, 3)); - System.out.println("instanceMethod(10, 2) = " + w.instanceMethod(10, 2)); System.out.println("Workload complete."); } diff --git a/tests/test_languages/test_java/test_java_config_detection.py b/tests/test_languages/test_java/test_java_config_detection.py new file mode 100644 index 000000000..ebb8653af --- /dev/null +++ b/tests/test_languages/test_java/test_java_config_detection.py @@ -0,0 +1,444 @@ +"""Tests for Java project auto-detection from Maven/Gradle build files. + +Tests that codeflash can detect Java projects and infer module-root, +tests-root, and other config from pom.xml / build.gradle / gradle.properties +without requiring a standalone codeflash.toml config file. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from codeflash.languages.java.build_tools import ( + BuildTool, + detect_build_tool, + find_source_root, + find_test_root, + parse_java_project_config, +) + + +# --------------------------------------------------------------------------- +# Build tool detection +# --------------------------------------------------------------------------- + + +class TestDetectBuildTool: + def test_detect_maven(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + assert detect_build_tool(tmp_path) == BuildTool.MAVEN + + def test_detect_gradle(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + assert detect_build_tool(tmp_path) == BuildTool.GRADLE + + def test_detect_gradle_kts(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle.kts").write_text("", encoding="utf-8") + assert detect_build_tool(tmp_path) == BuildTool.GRADLE + + def test_maven_takes_priority_over_gradle(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + assert detect_build_tool(tmp_path) == BuildTool.MAVEN + + def test_unknown_when_no_build_file(self, tmp_path: Path) -> None: + assert detect_build_tool(tmp_path) == BuildTool.UNKNOWN + + def test_detect_maven_in_parent(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + child = tmp_path / "module" + child.mkdir() + assert detect_build_tool(child) == BuildTool.MAVEN + + +# --------------------------------------------------------------------------- +# Source / test root detection (standard layouts) +# --------------------------------------------------------------------------- + + +class TestFindSourceRoot: + def test_standard_maven_layout(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + assert find_source_root(tmp_path) == src + + def test_fallback_to_src_with_java_files(self, tmp_path: Path) -> None: + src = tmp_path / "src" + src.mkdir() + (src / "App.java").write_text("class App {}", encoding="utf-8") + assert find_source_root(tmp_path) == src + + def test_returns_none_when_no_source(self, tmp_path: Path) -> None: + assert find_source_root(tmp_path) is None + + +class TestFindTestRoot: + def test_standard_maven_layout(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + assert find_test_root(tmp_path) == test + + def test_fallback_to_test_dir(self, tmp_path: Path) -> None: + test = tmp_path / "test" + test.mkdir() + assert find_test_root(tmp_path) == test + + def test_fallback_to_tests_dir(self, tmp_path: Path) -> None: + tests = tmp_path / "tests" + tests.mkdir() + assert find_test_root(tmp_path) == tests + + def test_returns_none_when_no_test_dir(self, tmp_path: Path) -> None: + assert find_test_root(tmp_path) is None + + +# --------------------------------------------------------------------------- +# parse_java_project_config — standard layouts +# --------------------------------------------------------------------------- + + +class TestParseJavaProjectConfigStandard: + def test_standard_maven_project(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["language"] == "java" + assert config["module_root"] == str(src) + assert config["tests_root"] == str(test) + + def test_standard_gradle_project(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["language"] == "java" + assert config["module_root"] == str(src) + assert config["tests_root"] == str(test) + + def test_returns_none_for_non_java_project(self, tmp_path: Path) -> None: + assert parse_java_project_config(tmp_path) is None + + def test_defaults_when_dirs_missing(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + config = parse_java_project_config(tmp_path) + assert config is not None + # Falls back to default paths even if they don't exist + assert str(tmp_path / "src" / "main" / "java") == config["module_root"] + assert config["language"] == "java" + + +# --------------------------------------------------------------------------- +# parse_java_project_config — Maven properties (codeflash.*) +# --------------------------------------------------------------------------- + +MAVEN_POM_WITH_PROPERTIES = """\ + + 4.0.0 + com.example + test + 1.0 + + custom/src + custom/test + true + upstream + gen/,build/ + + +""" + + +class TestMavenCodeflashProperties: + def test_reads_custom_properties(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text(MAVEN_POM_WITH_PROPERTIES, encoding="utf-8") + (tmp_path / "custom" / "src").mkdir(parents=True) + (tmp_path / "custom" / "test").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["module_root"] == str((tmp_path / "custom" / "src").resolve()) + assert config["tests_root"] == str((tmp_path / "custom" / "test").resolve()) + assert config["disable_telemetry"] is True + assert config["git_remote"] == "upstream" + assert len(config["ignore_paths"]) == 2 + + def test_properties_override_auto_detection(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text(MAVEN_POM_WITH_PROPERTIES, encoding="utf-8") + # Create standard dirs AND custom dirs + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "custom" / "src").mkdir(parents=True) + (tmp_path / "custom" / "test").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + # Should use custom paths from properties, not auto-detected standard paths + assert config["module_root"] == str((tmp_path / "custom" / "src").resolve()) + + def test_no_properties_uses_defaults(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text( + '4.0.0', + encoding="utf-8", + ) + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["disable_telemetry"] is False + assert config["git_remote"] == "origin" + + +# --------------------------------------------------------------------------- +# parse_java_project_config — Gradle properties +# --------------------------------------------------------------------------- + + +class TestGradleCodeflashProperties: + def test_reads_gradle_properties(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "gradle.properties").write_text( + "codeflash.moduleRoot=lib/src\ncodeflash.testsRoot=lib/test\ncodeflash.disableTelemetry=true\n", + encoding="utf-8", + ) + (tmp_path / "lib" / "src").mkdir(parents=True) + (tmp_path / "lib" / "test").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["module_root"] == str((tmp_path / "lib" / "src").resolve()) + assert config["tests_root"] == str((tmp_path / "lib" / "test").resolve()) + assert config["disable_telemetry"] is True + + def test_ignores_non_codeflash_properties(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "gradle.properties").write_text( + "org.gradle.jvmargs=-Xmx2g\ncodeflash.gitRemote=upstream\n", + encoding="utf-8", + ) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["git_remote"] == "upstream" + + def test_no_gradle_properties_uses_defaults(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "src" / "test" / "java").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["git_remote"] == "origin" + assert config["disable_telemetry"] is False + + +# --------------------------------------------------------------------------- +# Multi-module Maven projects +# --------------------------------------------------------------------------- + +PARENT_POM = """\ + + 4.0.0 + com.example + parent + 1.0 + pom + + client + test + examples + + +""" + +CLIENT_POM = """\ + + 4.0.0 + + com.example + parent + 1.0 + + client + + ${project.basedir}/src + + +""" + +TEST_POM = """\ + + 4.0.0 + + com.example + parent + 1.0 + + test + + ${project.basedir}/src + + +""" + +EXAMPLES_POM = """\ + + 4.0.0 + + com.example + parent + 1.0 + + examples + + ${project.basedir}/src + + +""" + + +class TestMultiModuleMaven: + @pytest.fixture + def multi_module_project(self, tmp_path: Path) -> Path: + """Create a multi-module Maven project mimicking aerospike's layout.""" + (tmp_path / "pom.xml").write_text(PARENT_POM, encoding="utf-8") + + # Client module — main library with the most Java files + client = tmp_path / "client" + client.mkdir() + (client / "pom.xml").write_text(CLIENT_POM, encoding="utf-8") + client_src = client / "src" / "com" / "example" / "client" + client_src.mkdir(parents=True) + for i in range(10): + (client_src / f"Class{i}.java").write_text(f"class Class{i} {{}}", encoding="utf-8") + + # Test module — test code + test = tmp_path / "test" + test.mkdir() + (test / "pom.xml").write_text(TEST_POM, encoding="utf-8") + test_src = test / "src" / "com" / "example" / "test" + test_src.mkdir(parents=True) + (test_src / "ClientTest.java").write_text("class ClientTest {}", encoding="utf-8") + + # Examples module — should be skipped + examples = tmp_path / "examples" + examples.mkdir() + (examples / "pom.xml").write_text(EXAMPLES_POM, encoding="utf-8") + examples_src = examples / "src" / "com" / "example" + examples_src.mkdir(parents=True) + (examples_src / "Example.java").write_text("class Example {}", encoding="utf-8") + + return tmp_path + + def test_detects_client_as_source_root(self, multi_module_project: Path) -> None: + config = parse_java_project_config(multi_module_project) + assert config is not None + assert config["module_root"] == str(multi_module_project / "client" / "src") + + def test_detects_test_module_as_test_root(self, multi_module_project: Path) -> None: + config = parse_java_project_config(multi_module_project) + assert config is not None + assert config["tests_root"] == str(multi_module_project / "test" / "src") + + def test_skips_examples_module(self, multi_module_project: Path) -> None: + config = parse_java_project_config(multi_module_project) + assert config is not None + # The module_root should be client/src, not examples/src + assert config["module_root"] == str(multi_module_project / "client" / "src") + + def test_picks_module_with_most_java_files(self, multi_module_project: Path) -> None: + """Client has 10 .java files, examples has 1 — client should win.""" + config = parse_java_project_config(multi_module_project) + assert config is not None + assert "client" in config["module_root"] + + +# --------------------------------------------------------------------------- +# Language detection from config_parser +# --------------------------------------------------------------------------- + + +class TestLanguageDetectionViaConfigParser: + def test_java_detected_from_pom_xml(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "src" / "test" / "java").mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import _try_parse_java_build_config + + result = _try_parse_java_build_config() + assert result is not None + config, project_root = result + assert config["language"] == "java" + assert project_root == tmp_path + + def test_java_detected_from_build_gradle(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import _try_parse_java_build_config + + result = _try_parse_java_build_config() + assert result is not None + config, _ = result + assert config["language"] == "java" + + def test_no_java_detected_for_python_project(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + (tmp_path / "pyproject.toml").write_text("[tool.codeflash]\nmodule-root='src'\ntests-root='tests'\n", encoding="utf-8") + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import _try_parse_java_build_config + + result = _try_parse_java_build_config() + assert result is None + + +# --------------------------------------------------------------------------- +# Language detection from tracer +# --------------------------------------------------------------------------- + + +class TestTracerLanguageDetection: + def test_detects_java_from_build_files(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + monkeypatch.chdir(tmp_path) + + from codeflash.languages.base import Language + from codeflash.tracer import _detect_non_python_language + + result = _detect_non_python_language(None) + assert result == Language.JAVA + + def test_no_detection_without_build_files(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + + from codeflash.tracer import _detect_non_python_language + + result = _detect_non_python_language(None) + assert result is None + + def test_detects_java_from_file_extension(self, tmp_path: Path) -> None: + java_file = tmp_path / "App.java" + java_file.write_text("class App {}", encoding="utf-8") + + from argparse import Namespace + + from codeflash.languages.base import Language + from codeflash.tracer import _detect_non_python_language + + args = Namespace(file=str(java_file)) + result = _detect_non_python_language(args) + assert result == Language.JAVA diff --git a/tests/test_languages/test_java/test_jfr_parser.py b/tests/test_languages/test_java/test_jfr_parser.py new file mode 100644 index 000000000..8b5cf8a6e --- /dev/null +++ b/tests/test_languages/test_java/test_jfr_parser.py @@ -0,0 +1,302 @@ +"""Tests for JFR parser — class name normalization, package filtering, addressable time.""" + +from __future__ import annotations + +import json +import subprocess +from pathlib import Path +from unittest.mock import patch + +import pytest + +from codeflash.languages.java.jfr_parser import JfrProfile + + +def _make_jfr_json(events: list[dict]) -> str: + """Create fake JFR JSON output matching the jfr print format.""" + return json.dumps({"recording": {"events": events}}) + + +def _make_execution_sample(class_name: str, method_name: str, start_time: str = "2026-01-01T00:00:00Z") -> dict: + return { + "type": "jdk.ExecutionSample", + "values": { + "startTime": start_time, + "stackTrace": { + "frames": [ + { + "method": { + "type": {"name": class_name}, + "name": method_name, + "descriptor": "()V", + }, + "lineNumber": 42, + } + ], + }, + }, + } + + +class TestClassNameNormalization: + """Test that JVM internal class names (com/example/Foo) are normalized to dots (com.example.Foo).""" + + def test_slash_separators_normalized_to_dots(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"), + _make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"), + _make_execution_sample("com/aerospike/client/util/Utf8", "encodedLength"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.aerospike"]) + + assert profile._total_samples == 3 + assert len(profile._method_samples) == 2 + + # Keys should use dots, not slashes + assert "com.aerospike.client.command.Buffer.bytesToInt" in profile._method_samples + assert "com.aerospike.client.util.Utf8.encodedLength" in profile._method_samples + + def test_method_info_uses_dot_class_names(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/MyClass", "myMethod")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + info = profile._method_info.get("com.example.MyClass.myMethod") + assert info is not None + assert info["class_name"] == "com.example.MyClass" + assert info["method_name"] == "myMethod" + + +class TestPackageFiltering: + def test_filters_by_package_prefix(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/aerospike/client/Value", "get"), + _make_execution_sample("java/util/HashMap", "put"), + _make_execution_sample("com/aerospike/benchmarks/Main", "main"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.aerospike"]) + + # Only com.aerospike classes should be in samples + assert len(profile._method_samples) == 2 + assert "com.aerospike.client.Value.get" in profile._method_samples + assert "com.aerospike.benchmarks.Main.main" in profile._method_samples + assert "java.util.HashMap.put" not in profile._method_samples + + def test_empty_packages_includes_all(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/Foo", "bar"), + _make_execution_sample("java/lang/String", "length"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, []) + + assert len(profile._method_samples) == 2 + + +class TestAddressableTime: + def test_addressable_time_proportional_to_samples(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + # 3 samples for methodA, 1 for methodB, spanning 10 seconds + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:00Z"), + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:03Z"), + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:06Z"), + _make_execution_sample("com/example/Foo", "methodB", "2026-01-01T00:00:10Z"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + time_a = profile.get_addressable_time_ns("com.example.Foo", "methodA") + time_b = profile.get_addressable_time_ns("com.example.Foo", "methodB") + + # methodA has 3x the samples of methodB, so 3x the addressable time + assert time_a > 0 + assert time_b > 0 + assert time_a == pytest.approx(time_b * 3, rel=0.01) + + def test_addressable_time_zero_for_unknown_method(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/Foo", "bar")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + assert profile.get_addressable_time_ns("com.example.Foo", "nonExistent") == 0.0 + + +class TestMethodRanking: + def test_ranking_ordered_by_sample_count(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/B", "warm"), + _make_execution_sample("com/example/B", "warm"), + _make_execution_sample("com/example/C", "cold"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + ranking = profile.get_method_ranking() + assert len(ranking) == 3 + assert ranking[0]["method_name"] == "hot" + assert ranking[0]["sample_count"] == 3 + assert ranking[1]["method_name"] == "warm" + assert ranking[1]["sample_count"] == 2 + assert ranking[2]["method_name"] == "cold" + assert ranking[2]["sample_count"] == 1 + + def test_empty_ranking_when_no_samples(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json([]) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + assert profile.get_method_ranking() == [] + + def test_ranking_uses_dot_class_names(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/nested/Deep", "method")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + ranking = profile.get_method_ranking() + assert len(ranking) == 1 + assert ranking[0]["class_name"] == "com.example.nested.Deep" + + +class TestGracefulTimeout: + """Test that _run_java_with_graceful_timeout sends SIGTERM before SIGKILL.""" + + def test_sends_sigterm_on_timeout(self) -> None: + import signal + + from codeflash.languages.java.tracer import _run_java_with_graceful_timeout + + # Run a sleep command with a 1s timeout — should get SIGTERM'd + import os + + env = os.environ.copy() + _run_java_with_graceful_timeout(["sleep", "60"], env, timeout=1, stage_name="test") + # If we get here, the process was killed (didn't hang for 60s) + + def test_no_timeout_runs_normally(self) -> None: + import os + + from codeflash.languages.java.tracer import _run_java_with_graceful_timeout + + env = os.environ.copy() + _run_java_with_graceful_timeout(["echo", "hello"], env, timeout=0, stage_name="test") + # Should complete without error + + +class TestProjectRootResolution: + """Test that project_root is correctly set for Java multi-module projects.""" + + def test_java_project_root_is_build_root_not_module(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """For multi-module Maven, project_root should be the root with , not a sub-module.""" + # Create a multi-module project + (tmp_path / "pom.xml").write_text( + 'client', + encoding="utf-8", + ) + client = tmp_path / "client" + client.mkdir() + (client / "pom.xml").write_text("", encoding="utf-8") + src = client / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import parse_config_file + + config, config_path = parse_config_file() + assert config["language"] == "java" + + # config_path should be the project root directory + assert config_path == tmp_path + + def test_project_root_is_path_not_string(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """project_root from process_pyproject_config should be a Path for Java projects.""" + from argparse import Namespace + + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.cli_cmds.cli import process_pyproject_config + + # Create a minimal args namespace matching what parse_args produces + args = Namespace( + config_file=None, module_root=None, tests_root=None, benchmarks_root=None, + ignore_paths=None, pytest_cmd=None, formatter_cmds=None, disable_telemetry=None, + disable_imports_sorting=None, git_remote=None, override_fixtures=None, + benchmark=False, verbose=False, version=False, show_config=False, reset_config=False, + ) + args = process_pyproject_config(args) + + assert hasattr(args, "project_root") + assert isinstance(args.project_root, Path) + assert args.project_root == tmp_path diff --git a/tests/test_languages/test_java/test_replay_test_generation.py b/tests/test_languages/test_java/test_replay_test_generation.py new file mode 100644 index 000000000..da7138114 --- /dev/null +++ b/tests/test_languages/test_java/test_replay_test_generation.py @@ -0,0 +1,255 @@ +"""Tests for Java replay test generation — JUnit 4/5 support, overload handling, instrumentation skip.""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path + +import pytest + +from codeflash.languages.java.replay_test import generate_replay_tests, parse_replay_test_metadata + + +@pytest.fixture +def trace_db(tmp_path: Path) -> Path: + """Create a trace database with sample function calls.""" + db_path = tmp_path / "trace.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + "CREATE TABLE function_calls(" + "type TEXT, function TEXT, classname TEXT, filename TEXT, " + "line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)" + ) + conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)") + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 1000, b"\x00"), + ) + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 2000, b"\x00"), + ) + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "multiply", "com.example.Calculator", "Calculator.java", 20, "(II)I", 3000, b"\x00"), + ) + conn.commit() + conn.close() + return db_path + + +@pytest.fixture +def trace_db_overloaded(tmp_path: Path) -> Path: + """Create a trace database with overloaded methods (same name, different descriptors).""" + db_path = tmp_path / "trace_overloaded.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + "CREATE TABLE function_calls(" + "type TEXT, function TEXT, classname TEXT, filename TEXT, " + "line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)" + ) + conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)") + # Two overloads of estimateKeySize with different descriptors + for i in range(3): + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "estimateKeySize", "com.example.Command", "Command.java", 10, "(I)I", i * 1000, b"\x00"), + ) + for i in range(2): + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + "call", + "estimateKeySize", + "com.example.Command", + "Command.java", + 15, + "(Ljava/lang/String;)I", + (i + 10) * 1000, + b"\x00", + ), + ) + conn.commit() + conn.close() + return db_path + + +class TestGenerateReplayTestsJunit5: + def test_generates_junit5_by_default(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db, output_dir, tmp_path) + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "import org.junit.jupiter.api.Test;" in content + assert "import org.junit.jupiter.api.AfterAll;" in content + assert "@Test void replay_add_0()" in content + + def test_junit5_class_is_package_private(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "class ReplayTest_" in content + assert "public class ReplayTest_" not in content + + +class TestGenerateReplayTestsJunit4: + def test_generates_junit4_imports(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "import org.junit.Test;" in content + assert "import org.junit.AfterClass;" in content + assert "org.junit.jupiter" not in content + + def test_junit4_methods_are_public(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "@Test public void replay_add_0()" in content + + def test_junit4_class_is_public(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "public class ReplayTest_" in content + + def test_junit4_cleanup_uses_afterclass(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "@AfterClass" in content + assert "@AfterAll" not in content + + +class TestOverloadedMethods: + def test_no_duplicate_method_names(self, trace_db_overloaded: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db_overloaded, output_dir, tmp_path) + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + + # Should have 5 unique methods (3 from first overload + 2 from second) + assert "replay_estimateKeySize_0" in content + assert "replay_estimateKeySize_1" in content + assert "replay_estimateKeySize_2" in content + assert "replay_estimateKeySize_3" in content + assert "replay_estimateKeySize_4" in content + + # Verify no duplicates by counting occurrences + lines = content.splitlines() + method_lines = [l for l in lines if "void replay_estimateKeySize_" in l] + method_names = [l.split("void ")[1].split("(")[0] for l in method_lines] + assert len(method_names) == len(set(method_names)), f"Duplicate methods: {method_names}" + + +class TestReplayTestInstrumentation: + def test_replay_tests_instrumented_correctly(self, trace_db: Path, tmp_path: Path) -> None: + """Replay tests with compact @Test lines should be instrumented without orphaned code.""" + from codeflash.languages.java.discovery import discover_functions_from_source + + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="behavior", + ) + assert success + assert instrumented is not None + assert "__perfinstrumented" in instrumented + + # Verify no code outside class body + lines = instrumented.splitlines() + class_closed = False + for line in lines: + if line.strip() == "}" and not line.startswith(" "): + class_closed = True + elif class_closed and line.strip() and not line.strip().startswith("//"): + pytest.fail(f"Orphaned code outside class: {line}") + + def test_replay_tests_perf_instrumented(self, trace_db: Path, tmp_path: Path) -> None: + from codeflash.languages.java.discovery import discover_functions_from_source + + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="performance", + ) + assert success + assert "__perfonlyinstrumented" in instrumented + + def test_regular_tests_still_instrumented(self, tmp_path: Path) -> None: + from codeflash.languages.java.discovery import discover_functions_from_source + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + test_file = tmp_path / "CalculatorTest.java" + test_file.write_text( + """ +import org.junit.jupiter.api.Test; +public class CalculatorTest { + @Test + public void testAdd() { + Calculator calc = new Calculator(); + calc.add(1, 2); + } +} +""", + encoding="utf-8", + ) + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="behavior", + ) + assert success + assert "CODEFLASH_LOOP_INDEX" in instrumented