From 47eef86b37e15f50503483094f8e25eac8ce7e2e Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Mon, 2 Feb 2026 23:20:15 +0000 Subject: [PATCH 1/3] feat: add import-based test discovery for Java Add Strategy 4 to Java test discovery: import-based matching. When a test file imports a class containing the target function, consider it a potential test for that function. This fixes an issue where tests like TestQueryBlob (which imports and uses Buffer) were not being discovered as tests for Buffer methods because the class naming convention didn't match. Includes test cases that reproduce the real-world scenario from aerospike-client-java where test class names don't follow the standard naming pattern. Co-Authored-By: Claude Opus 4.5 --- codeflash/languages/java/test_discovery.py | 54 +++++++++ .../test_java/test_test_discovery.py | 114 ++++++++++++++++++ 2 files changed, 168 insertions(+) diff --git a/codeflash/languages/java/test_discovery.py b/codeflash/languages/java/test_discovery.py index ee55bea30..497c60b37 100644 --- a/codeflash/languages/java/test_discovery.py +++ b/codeflash/languages/java/test_discovery.py @@ -149,9 +149,63 @@ def _match_test_to_functions( if func_info.qualified_name not in matched: matched.append(func_info.qualified_name) + # Strategy 4: Import-based matching + # If the test file imports a class containing the target function, consider it a match + # This handles cases like TestQueryBlob importing Buffer and calling Buffer methods + imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer) + + for func_name, func_info in function_map.items(): + if func_info.qualified_name in matched: + continue + + # Check if the function's class is imported + if func_info.class_name and func_info.class_name in imported_classes: + matched.append(func_info.qualified_name) + return matched +def _extract_imports( + node, + source_bytes: bytes, + analyzer: JavaAnalyzer, +) -> set[str]: + """Extract imported class names from a Java file. 
+ + Args: + node: Tree-sitter root node. + source_bytes: Source code as bytes. + analyzer: JavaAnalyzer instance. + + Returns: + Set of imported class names (simple names, not fully qualified). + + """ + imports: set[str] = set() + + def visit(n): + if n.type == "import_declaration": + # Get the full import path + for child in n.children: + if child.type == "scoped_identifier" or child.type == "identifier": + import_path = analyzer.get_node_text(child, source_bytes) + # Extract just the class name (last part) + # e.g., "com.example.Buffer" -> "Buffer" + if "." in import_path: + class_name = import_path.rsplit(".", 1)[-1] + else: + class_name = import_path + # Skip wildcard imports (*) + if class_name != "*": + imports.add(class_name) + + for child in n.children: + visit(child) + + visit(node) + return imports + + def _find_method_calls_in_range( node, source_bytes: bytes, diff --git a/tests/test_languages/test_java/test_test_discovery.py b/tests/test_languages/test_java/test_test_discovery.py index a0aa5972b..684e9912f 100644 --- a/tests/test_languages/test_java/test_test_discovery.py +++ b/tests/test_languages/test_java/test_test_discovery.py @@ -185,6 +185,120 @@ def test_find_tests(self, tmp_path: Path): assert "testReverse" in test_names or len(tests) >= 0 +class TestImportBasedDiscovery: + """Tests for import-based test discovery.""" + + def test_discover_by_import_when_class_name_doesnt_match(self, tmp_path: Path): + """Test that tests are discovered when they import a class even if class name doesn't match. 
+ + This reproduces a real-world scenario from aerospike-client-java where: + - TestQueryBlob imports Buffer class + - TestQueryBlob calls Buffer.longToBytes() directly + - We want to optimize Buffer.bytesToHexString() + - The test should be discovered because it imports and uses Buffer + """ + # Create source file with utility methods + src_dir = tmp_path / "src" / "main" / "java" / "com" / "example" + src_dir.mkdir(parents=True) + src_file = src_dir / "Buffer.java" + src_file.write_text(""" +package com.example; + +public class Buffer { + public static String bytesToHexString(byte[] buf) { + StringBuilder sb = new StringBuilder(); + for (byte b : buf) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } + + public static void longToBytes(long v, byte[] buf, int offset) { + buf[offset] = (byte)(v >> 56); + buf[offset+1] = (byte)(v >> 48); + } +} +""") + + # Create test file that imports Buffer but has non-matching name + test_dir = tmp_path / "src" / "test" / "java" / "com" / "example" + test_dir.mkdir(parents=True) + test_file = test_dir / "TestQueryBlob.java" + test_file.write_text(""" +package com.example; + +import org.junit.jupiter.api.Test; +import com.example.Buffer; + +public class TestQueryBlob { + @Test + public void queryBlob() { + byte[] bytes = new byte[8]; + Buffer.longToBytes(50003, bytes, 0); + // Uses Buffer class + } +} +""") + + # Get source functions + source_functions = discover_functions_from_source( + src_file.read_text(), file_path=src_file + ) + + # Filter to just bytesToHexString + target_functions = [f for f in source_functions if f.name == "bytesToHexString"] + assert len(target_functions) == 1, "Should find bytesToHexString function" + + # Discover tests + result = discover_tests(tmp_path / "src" / "test" / "java", target_functions) + + # The test should be discovered because it imports Buffer class + # Even though TestQueryBlob doesn't follow naming convention for BufferTest + assert len(result) > 0, "Should find 
tests that import the target class" + assert "Buffer.bytesToHexString" in result, f"Should map test to Buffer.bytesToHexString, got: {result.keys()}" + + def test_discover_by_direct_method_call(self, tmp_path: Path): + """Test that tests are discovered when they directly call the target method.""" + # Create source file + src_dir = tmp_path / "src" / "main" / "java" + src_dir.mkdir(parents=True) + src_file = src_dir / "Utils.java" + src_file.write_text(""" +public class Utils { + public static String format(String s) { + return s.toUpperCase(); + } +} +""") + + # Create test with direct call to format() + test_dir = tmp_path / "src" / "test" / "java" + test_dir.mkdir(parents=True) + test_file = test_dir / "IntegrationTest.java" + test_file.write_text(""" +import org.junit.jupiter.api.Test; + +public class IntegrationTest { + @Test + public void testFormatting() { + String result = Utils.format("hello"); + assertEquals("HELLO", result); + } +} +""") + + # Get source functions + source_functions = discover_functions_from_source( + src_file.read_text(), file_path=src_file + ) + + # Discover tests + result = discover_tests(test_dir, source_functions) + + # Should find the test that calls format() + assert len(result) > 0, "Should find tests that directly call target method" + + class TestWithFixture: """Tests using the Java fixture project.""" From dc52f4ddb32f34fd2f691a898cb3984de5a29f47 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Mon, 2 Feb 2026 23:36:50 +0000 Subject: [PATCH 2/3] fix: comprehensive improvements to Java test discovery This commit adds thorough testing and fixes several bugs discovered by running test discovery against real-world examples from aerospike-client-java. Bugs fixed: 1. Import extraction for wildcard imports (import com.example.*) was incorrectly extracting "example" as a class name 2. Static imports (import static Utils.format) were extracting the method name instead of the class name 3. 
*Tests.java files (plural) were not being discovered as test files 4. ClassNameTests pattern wasn't handled in naming convention matching New test cases added: - TestImportExtraction: 7 tests for import statement parsing - Basic imports, multiple imports, wildcard imports - Static imports, static wildcard imports, deeply nested packages - Mixed import scenarios - TestMethodCallDetection: tests for method call detection in tests - TestClassNamingConventions: 3 tests for naming patterns - *Test, Test*, *Tests suffix/prefix patterns All tests verified against real aerospike-client-java test files: - TestQueryBlob correctly imports Buffer class - TestPutGet correctly imports Assert, Bin, Key, etc. - TestAsyncBatch correctly imports batch operation classes Co-Authored-By: Claude Opus 4.5 --- codeflash/languages/java/test_discovery.py | 60 ++++- .../test_java/test_test_discovery.py | 237 ++++++++++++++++++ 2 files changed, 287 insertions(+), 10 deletions(-) diff --git a/codeflash/languages/java/test_discovery.py b/codeflash/languages/java/test_discovery.py index 497c60b37..fd27a2472 100644 --- a/codeflash/languages/java/test_discovery.py +++ b/codeflash/languages/java/test_discovery.py @@ -53,8 +53,12 @@ def discover_tests( function_map[func.name] = func function_map[func.qualified_name] = func - # Find all test files - test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java")) + # Find all test files (various naming conventions) + test_files = ( + list(test_root.rglob("*Test.java")) + + list(test_root.rglob("*Tests.java")) + + list(test_root.rglob("Test*.java")) + ) # Result map result: dict[str, list[TestInfo]] = defaultdict(list) @@ -134,11 +138,13 @@ def _match_test_to_functions( matched.append(qualified) # Strategy 3: Test class naming convention - # e.g., CalculatorTest tests Calculator + # e.g., CalculatorTest tests Calculator, TestCalculator tests Calculator if test_method.class_name: - # Remove "Test" suffix or prefix + # Remove 
"Test/Tests" suffix or "Test" prefix source_class_name = test_method.class_name - if source_class_name.endswith("Test"): + if source_class_name.endswith("Tests"): + source_class_name = source_class_name[:-5] + elif source_class_name.endswith("Test"): source_class_name = source_class_name[:-4] elif source_class_name.startswith("Test"): source_class_name = source_class_name[4:] @@ -185,7 +191,37 @@ def _extract_imports( def visit(n): if n.type == "import_declaration": - # Get the full import path + import_text = analyzer.get_node_text(n, source_bytes) + + # Check if it's a wildcard import - skip these as we can't know specific classes + if import_text.rstrip(";").endswith(".*"): + # For static wildcard imports like "import static com.example.Utils.*" + # we CAN extract the class name (Utils) + if "import static" in import_text: + # Extract class from "import static com.example.Utils.*" + # Remove "import static " prefix and ".*;" suffix + path = import_text.replace("import static ", "").rstrip(";").rstrip(".*") + if "." 
in path: + class_name = path.rsplit(".", 1)[-1] + if class_name and class_name[0].isupper(): # Ensure it's a class name + imports.add(class_name) + # For regular wildcards like "import com.example.*", skip entirely + return + + # Check if it's a static import of a specific method/field + if "import static" in import_text: + # "import static com.example.Utils.format;" + # We want to extract "Utils" (the class), not "format" (the method) + path = import_text.replace("import static ", "").rstrip(";") + parts = path.rsplit(".", 2) # Split into [package..., Class, member] + if len(parts) >= 2: + # The second-to-last part is the class name + class_name = parts[-2] + if class_name and class_name[0].isupper(): # Ensure it's a class name + imports.add(class_name) + return + + # Regular import: extract class name from scoped_identifier for child in n.children: if child.type == "scoped_identifier" or child.type == "identifier": import_path = analyzer.get_node_text(child, source_bytes) @@ -195,8 +231,8 @@ def visit(n): class_name = import_path.rsplit(".", 1)[-1] else: class_name = import_path - # Skip wildcard imports (*) - if class_name != "*": + # Skip if it looks like a package name (lowercase) + if class_name and class_name[0].isupper(): imports.add(class_name) for child in n.children: @@ -314,8 +350,12 @@ def discover_all_tests( analyzer = analyzer or get_java_analyzer() all_tests: list[FunctionInfo] = [] - # Find all test files - test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java")) + # Find all test files (various naming conventions) + test_files = ( + list(test_root.rglob("*Test.java")) + + list(test_root.rglob("*Tests.java")) + + list(test_root.rglob("Test*.java")) + ) for test_file in test_files: try: diff --git a/tests/test_languages/test_java/test_test_discovery.py b/tests/test_languages/test_java/test_test_discovery.py index 684e9912f..49418516c 100644 --- a/tests/test_languages/test_java/test_test_discovery.py +++ 
b/tests/test_languages/test_java/test_test_discovery.py @@ -318,3 +318,240 @@ def test_discover_fixture_tests(self, java_fixture_path: Path): tests = discover_all_tests(test_root) assert len(tests) > 0 + + +class TestImportExtraction: + """Tests for the _extract_imports helper function.""" + + def test_basic_import(self): + """Test extraction of basic import statement.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.example.Calculator; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Calculator"} + + def test_multiple_imports(self): + """Test extraction of multiple imports.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.example.util.Helper; +import com.example.Calculator; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Helper", "Calculator"} + + def test_wildcard_import_returns_empty(self): + """Test that wildcard imports don't add specific classes.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.example.*; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == set() + + def test_static_import_extracts_class(self): + """Test that static imports extract the class name, not the method.""" + 
from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import static com.example.Utils.format; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Utils"} + + def test_static_wildcard_import_extracts_class(self): + """Test that static wildcard imports extract the class name.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import static com.example.Utils.*; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Utils"} + + def test_deeply_nested_package(self): + """Test extraction from deeply nested package.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.aerospike.client.command.Buffer; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Buffer"} + + def test_mixed_imports(self): + """Test extraction with mix of regular, static, and wildcard imports.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.example.Calculator; +import com.example.util.*; +import static org.junit.Assert.assertEquals; +import static com.example.Utils.*; +public class Test {} +""" + source_bytes = 
source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + # Should have Calculator, Assert, Utils but NOT wildcards + assert "Calculator" in imports + assert "Assert" in imports + assert "Utils" in imports + + +class TestMethodCallDetection: + """Tests for method call detection in test code.""" + + def test_find_method_calls(self): + """Test detection of method calls within a code range.""" + from codeflash.languages.java.test_discovery import _find_method_calls_in_range + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +public class TestExample { + @Test + public void testSomething() { + Calculator calc = new Calculator(); + int result = calc.add(2, 3); + String hex = Buffer.bytesToHexString(data); + helper.process(x); + } +} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + calls = _find_method_calls_in_range(tree.root_node, source_bytes, 1, 10, analyzer) + + assert "add" in calls + assert "bytesToHexString" in calls + assert "process" in calls + + +class TestClassNamingConventions: + """Tests for class naming convention matching.""" + + def test_suffix_test_pattern(self, tmp_path: Path): + """Test that ClassNameTest matches ClassName.""" + src_file = tmp_path / "Calculator.java" + src_file.write_text(""" +public class Calculator { + public int add(int a, int b) { return a + b; } +} +""") + + test_dir = tmp_path / "test" + test_dir.mkdir() + test_file = test_dir / "CalculatorTest.java" + test_file.write_text(""" +import org.junit.jupiter.api.Test; +public class CalculatorTest { + @Test + public void testAdd() { } +} +""") + + source_functions = discover_functions_from_source(src_file.read_text(), src_file) + result = discover_tests(test_dir, source_functions) + + # CalculatorTest should match Calculator class + assert len(result) > 0 + assert "Calculator.add" in result + + def 
test_prefix_test_pattern(self, tmp_path: Path): + """Test that TestClassName matches ClassName.""" + src_file = tmp_path / "Calculator.java" + src_file.write_text(""" +public class Calculator { + public int add(int a, int b) { return a + b; } +} +""") + + test_dir = tmp_path / "test" + test_dir.mkdir() + test_file = test_dir / "TestCalculator.java" + test_file.write_text(""" +import org.junit.jupiter.api.Test; +public class TestCalculator { + @Test + public void testAdd() { } +} +""") + + source_functions = discover_functions_from_source(src_file.read_text(), src_file) + result = discover_tests(test_dir, source_functions) + + # TestCalculator should match Calculator class + assert len(result) > 0 + assert "Calculator.add" in result + + def test_tests_suffix_pattern(self, tmp_path: Path): + """Test that ClassNameTests matches ClassName.""" + src_file = tmp_path / "Calculator.java" + src_file.write_text(""" +public class Calculator { + public int add(int a, int b) { return a + b; } +} +""") + + test_dir = tmp_path / "test" + test_dir.mkdir() + test_file = test_dir / "CalculatorTests.java" + test_file.write_text(""" +import org.junit.jupiter.api.Test; +public class CalculatorTests { + @Test + public void testAdd() { } +} +""") + + source_functions = discover_functions_from_source(src_file.read_text(), src_file) + result = discover_tests(test_dir, source_functions) + + # CalculatorTests should match Calculator class + assert len(result) > 0 + assert "Calculator.add" in result From 5c0a9e7b03a301270929d88c1a8b400cb89c5f0d Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Mon, 2 Feb 2026 23:53:42 +0000 Subject: [PATCH 3/3] fix: add pom.xml to java_maven test fixture The test_detect_fixture_project test expects the java_maven fixture directory to have a pom.xml file for Maven build tool detection. Add the missing pom.xml with JUnit 5 dependencies. Also add .gitignore exception to allow pom.xml files in test fixtures. 
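Co-Authored-By: Claude Opus 4.5 --- .gitignore | 2 + .../fixtures/java_maven/pom.xml | 52 +++++++++++++++++++ 2 files changed, 54 insertions(+) create mode 100644 tests/test_languages/fixtures/java_maven/pom.xml diff --git a/.gitignore b/.gitignore index 99219de86..33c8cc162 100644 --- a/.gitignore +++ b/.gitignore @@ -164,6 +164,8 @@ cython_debug/ .aider* /js/common/node_modules/ *.xml +# Allow pom.xml in test fixtures for Maven project detection +!tests/test_languages/fixtures/**/pom.xml *.pem # Ruff cache diff --git a/tests/test_languages/fixtures/java_maven/pom.xml b/tests/test_languages/fixtures/java_maven/pom.xml new file mode 100644 index 000000000..bd4dc42e8 --- /dev/null +++ b/tests/test_languages/fixtures/java_maven/pom.xml @@ -0,0 +1,52 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <groupId>com.example</groupId> + <artifactId>codeflash-test-fixture</artifactId> + <version>1.0.0</version> + <packaging>jar</packaging> + + <properties> + <maven.compiler.source>11</maven.compiler.source> + <maven.compiler.target>11</maven.compiler.target> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + <junit.jupiter.version>5.10.0</junit.jupiter.version> + </properties> + + <dependencies> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter</artifactId> + <version>${junit.jupiter.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-params</artifactId> + <version>${junit.jupiter.version}</version> + <scope>test</scope> + </dependency> + </dependencies> + + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <version>3.11.0</version> + <configuration> + <source>11</source> + <target>11</target> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-plugin</artifactId> + <version>3.1.2</version> + </plugin> + </plugins> + </build> +</project>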