diff --git a/.gitignore b/.gitignore index 99219de86..33c8cc162 100644 --- a/.gitignore +++ b/.gitignore @@ -164,6 +164,8 @@ cython_debug/ .aider* /js/common/node_modules/ *.xml +# Allow pom.xml in test fixtures for Maven project detection +!tests/test_languages/fixtures/**/pom.xml *.pem # Ruff cache diff --git a/codeflash/languages/java/test_discovery.py b/codeflash/languages/java/test_discovery.py index ee55bea30..fd27a2472 100644 --- a/codeflash/languages/java/test_discovery.py +++ b/codeflash/languages/java/test_discovery.py @@ -53,8 +53,12 @@ def discover_tests( function_map[func.name] = func function_map[func.qualified_name] = func - # Find all test files - test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java")) + # Find all test files (various naming conventions) + test_files = ( + list(test_root.rglob("*Test.java")) + + list(test_root.rglob("*Tests.java")) + + list(test_root.rglob("Test*.java")) + ) # Result map result: dict[str, list[TestInfo]] = defaultdict(list) @@ -134,11 +138,13 @@ def _match_test_to_functions( matched.append(qualified) # Strategy 3: Test class naming convention - # e.g., CalculatorTest tests Calculator + # e.g., CalculatorTest tests Calculator, TestCalculator tests Calculator if test_method.class_name: - # Remove "Test" suffix or prefix + # Remove "Test/Tests" suffix or "Test" prefix source_class_name = test_method.class_name - if source_class_name.endswith("Test"): + if source_class_name.endswith("Tests"): + source_class_name = source_class_name[:-5] + elif source_class_name.endswith("Test"): source_class_name = source_class_name[:-4] elif source_class_name.startswith("Test"): source_class_name = source_class_name[4:] @@ -149,9 +155,93 @@ def _match_test_to_functions( if func_info.qualified_name not in matched: matched.append(func_info.qualified_name) + # Strategy 4: Import-based matching + # If the test file imports a class containing the target function, consider it a match + # This handles cases like TestQueryBlob importing Buffer and calling Buffer methods + imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer) + + for func_name, func_info in function_map.items(): + if func_info.qualified_name in matched: + continue + + # Check if the function's class is imported + if func_info.class_name and func_info.class_name in imported_classes: + matched.append(func_info.qualified_name) + return matched +def _extract_imports( + node, + source_bytes: bytes, + analyzer: JavaAnalyzer, +) -> set[str]: + """Extract imported class names from a Java file. + + Args: + node: Tree-sitter root node. + source_bytes: Source code as bytes. + analyzer: JavaAnalyzer instance. + + Returns: + Set of imported class names (simple names, not fully qualified). + + """ + imports: set[str] = set() + + def visit(n): + if n.type == "import_declaration": + import_text = analyzer.get_node_text(n, source_bytes) + + # Check if it's a wildcard import - skip these as we can't know specific classes + if import_text.rstrip(";").endswith(".*"): + # For static wildcard imports like "import static com.example.Utils.*" + # we CAN extract the class name (Utils) + if "import static" in import_text: + # Extract class from "import static com.example.Utils.*" + # Remove "import static " prefix and ".*;" suffix + path = import_text.replace("import static ", "").rstrip(";").rstrip(".*") + if "." in path: + class_name = path.rsplit(".", 1)[-1] + if class_name and class_name[0].isupper(): # Ensure it's a class name + imports.add(class_name) + # For regular wildcards like "import com.example.*", skip entirely + return + + # Check if it's a static import of a specific method/field + if "import static" in import_text: + # "import static com.example.Utils.format;" + # We want to extract "Utils" (the class), not "format" (the method) + path = import_text.replace("import static ", "").rstrip(";") + parts = path.rsplit(".", 2) # Split into [package..., Class, member] + if len(parts) >= 2: + # The second-to-last part is the class name + class_name = parts[-2] + if class_name and class_name[0].isupper(): # Ensure it's a class name + imports.add(class_name) + return + + # Regular import: extract class name from scoped_identifier + for child in n.children: + if child.type == "scoped_identifier" or child.type == "identifier": + import_path = analyzer.get_node_text(child, source_bytes) + # Extract just the class name (last part) + # e.g., "com.example.Buffer" -> "Buffer" + if "." in import_path: + class_name = import_path.rsplit(".", 1)[-1] + else: + class_name = import_path + # Skip if it looks like a package name (lowercase) + if class_name and class_name[0].isupper(): + imports.add(class_name) + + for child in n.children: + visit(child) + + visit(node) + return imports + + def _find_method_calls_in_range( node, source_bytes: bytes, @@ -260,8 +350,12 @@ def discover_all_tests( analyzer = analyzer or get_java_analyzer() all_tests: list[FunctionInfo] = [] - # Find all test files - test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java")) + # Find all test files (various naming conventions) + test_files = ( + list(test_root.rglob("*Test.java")) + + list(test_root.rglob("*Tests.java")) + + list(test_root.rglob("Test*.java")) + ) for test_file in test_files: try: diff --git a/tests/test_languages/fixtures/java_maven/pom.xml b/tests/test_languages/fixtures/java_maven/pom.xml new file mode 100644 index 000000000..bd4dc42e8 --- /dev/null +++ b/tests/test_languages/fixtures/java_maven/pom.xml @@ -0,0 +1,52 @@ + + + 4.0.0 + + com.example + codeflash-test-fixture + 1.0.0 + jar + + + 11 + 11 + UTF-8 + 5.10.0 + + + + + org.junit.jupiter + junit-jupiter + ${junit.jupiter.version} + test + + + org.junit.jupiter + junit-jupiter-params + ${junit.jupiter.version} + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.11.0 + + 11 + 11 + + + + org.apache.maven.plugins + maven-surefire-plugin + 3.1.2 + + + + diff --git a/tests/test_languages/test_java/test_test_discovery.py b/tests/test_languages/test_java/test_test_discovery.py index a0aa5972b..49418516c 100644 --- a/tests/test_languages/test_java/test_test_discovery.py +++ b/tests/test_languages/test_java/test_test_discovery.py @@ -185,6 +185,120 @@ def test_find_tests(self, tmp_path: Path): assert "testReverse" in test_names or len(tests) >= 0 +class TestImportBasedDiscovery: + """Tests for import-based test discovery.""" + + def test_discover_by_import_when_class_name_doesnt_match(self, tmp_path: Path): + """Test that tests are discovered when they import a class even if class name doesn't match. + + This reproduces a real-world scenario from aerospike-client-java where: + - TestQueryBlob imports Buffer class + - TestQueryBlob calls Buffer.longToBytes() directly + - We want to optimize Buffer.bytesToHexString() + - The test should be discovered because it imports and uses Buffer + """ + # Create source file with utility methods + src_dir = tmp_path / "src" / "main" / "java" / "com" / "example" + src_dir.mkdir(parents=True) + src_file = src_dir / "Buffer.java" + src_file.write_text(""" +package com.example; + +public class Buffer { + public static String bytesToHexString(byte[] buf) { + StringBuilder sb = new StringBuilder(); + for (byte b : buf) { + sb.append(String.format("%02x", b)); + } + return sb.toString(); + } + + public static void longToBytes(long v, byte[] buf, int offset) { + buf[offset] = (byte)(v >> 56); + buf[offset+1] = (byte)(v >> 48); + } +} +""") + + # Create test file that imports Buffer but has non-matching name + test_dir = tmp_path / "src" / "test" / "java" / "com" / "example" + test_dir.mkdir(parents=True) + test_file = test_dir / "TestQueryBlob.java" + test_file.write_text(""" +package com.example; + +import org.junit.jupiter.api.Test; +import com.example.Buffer; + +public class TestQueryBlob { + @Test + public void queryBlob() { + byte[] bytes = new byte[8]; + Buffer.longToBytes(50003, bytes, 0); + // Uses Buffer class + } +} +""") + + # Get source functions + source_functions = discover_functions_from_source( + src_file.read_text(), file_path=src_file + ) + + # Filter to just bytesToHexString + target_functions = [f for f in source_functions if f.name == "bytesToHexString"] + assert len(target_functions) == 1, "Should find bytesToHexString function" + + # Discover tests + result = discover_tests(tmp_path / "src" / "test" / "java", target_functions) + + # The test should be discovered because it imports Buffer class + # Even though TestQueryBlob doesn't follow naming convention for BufferTest + assert len(result) > 0, "Should find tests that import the target class" + assert "Buffer.bytesToHexString" in result, f"Should map test to Buffer.bytesToHexString, got: {result.keys()}" + + def test_discover_by_direct_method_call(self, tmp_path: Path): + """Test that tests are discovered when they directly call the target method.""" + # Create source file + src_dir = tmp_path / "src" / "main" / "java" + src_dir.mkdir(parents=True) + src_file = src_dir / "Utils.java" + src_file.write_text(""" +public class Utils { + public static String format(String s) { + return s.toUpperCase(); + } +} +""") + + # Create test with direct call to format() + test_dir = tmp_path / "src" / "test" / "java" + test_dir.mkdir(parents=True) + test_file = test_dir / "IntegrationTest.java" + test_file.write_text(""" +import org.junit.jupiter.api.Test; + +public class IntegrationTest { + @Test + public void testFormatting() { + String result = Utils.format("hello"); + assertEquals("HELLO", result); + } +} +""") + + # Get source functions + source_functions = discover_functions_from_source( + src_file.read_text(), file_path=src_file + ) + + # Discover tests + result = discover_tests(test_dir, source_functions) + + # Should find the test that calls format() + assert len(result) > 0, "Should find tests that directly call target method" + + class TestWithFixture: """Tests using the Java fixture project.""" @@ -204,3 +318,240 @@ def test_discover_fixture_tests(self, java_fixture_path: Path): tests = discover_all_tests(test_root) assert len(tests) > 0 + + +class TestImportExtraction: + """Tests for the _extract_imports helper function.""" + + def test_basic_import(self): + """Test extraction of basic import statement.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.example.Calculator; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Calculator"} + + def test_multiple_imports(self): + """Test extraction of multiple imports.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.example.util.Helper; +import com.example.Calculator; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Helper", "Calculator"} + + def test_wildcard_import_returns_empty(self): + """Test that wildcard imports don't add specific classes.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.example.*; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == set() + + def test_static_import_extracts_class(self): + """Test that static imports extract the class name, not the method.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import static com.example.Utils.format; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Utils"} + + def test_static_wildcard_import_extracts_class(self): + """Test that static wildcard imports extract the class name.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import static com.example.Utils.*; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Utils"} + + def test_deeply_nested_package(self): + """Test extraction from deeply nested package.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.aerospike.client.command.Buffer; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + assert imports == {"Buffer"} + + def test_mixed_imports(self): + """Test extraction with mix of regular, static, and wildcard imports.""" + from codeflash.languages.java.test_discovery import _extract_imports + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +import com.example.Calculator; +import com.example.util.*; +import static org.junit.Assert.assertEquals; +import static com.example.Utils.*; +public class Test {} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + imports = _extract_imports(tree.root_node, source_bytes, analyzer) + + # Should have Calculator, Assert, Utils but NOT wildcards + assert "Calculator" in imports + assert "Assert" in imports + assert "Utils" in imports + + +class TestMethodCallDetection: + """Tests for method call detection in test code.""" + + def test_find_method_calls(self): + """Test detection of method calls within a code range.""" + from codeflash.languages.java.test_discovery import _find_method_calls_in_range + from codeflash.languages.java.parser import get_java_analyzer + + analyzer = get_java_analyzer() + source = """ +public class TestExample { + @Test + public void testSomething() { + Calculator calc = new Calculator(); + int result = calc.add(2, 3); + String hex = Buffer.bytesToHexString(data); + helper.process(x); + } +} +""" + source_bytes = source.encode("utf8") + tree = analyzer.parse(source_bytes) + calls = _find_method_calls_in_range(tree.root_node, source_bytes, 1, 10, analyzer) + + assert "add" in calls + assert "bytesToHexString" in calls + assert "process" in calls + + +class TestClassNamingConventions: + """Tests for class naming convention matching.""" + + def test_suffix_test_pattern(self, tmp_path: Path): + """Test that ClassNameTest matches ClassName.""" + src_file = tmp_path / "Calculator.java" + src_file.write_text(""" +public class Calculator { + public int add(int a, int b) { return a + b; } +} +""") + + test_dir = tmp_path / "test" + test_dir.mkdir() + test_file = test_dir / "CalculatorTest.java" + test_file.write_text(""" +import org.junit.jupiter.api.Test; +public class CalculatorTest { + @Test + public void testAdd() { } +} +""") + + source_functions = discover_functions_from_source(src_file.read_text(), src_file) + result = discover_tests(test_dir, source_functions) + + # CalculatorTest should match Calculator class + assert len(result) > 0 + assert "Calculator.add" in result + + def test_prefix_test_pattern(self, tmp_path: Path): + """Test that TestClassName matches ClassName.""" + src_file = tmp_path / "Calculator.java" + src_file.write_text(""" +public class Calculator { + public int add(int a, int b) { return a + b; } +} +""") + + test_dir = tmp_path / "test" + test_dir.mkdir() + test_file = test_dir / "TestCalculator.java" + test_file.write_text(""" +import org.junit.jupiter.api.Test; +public class TestCalculator { + @Test + public void testAdd() { } +} +""") + + source_functions = discover_functions_from_source(src_file.read_text(), src_file) + result = discover_tests(test_dir, source_functions) + + # TestCalculator should match Calculator class + assert len(result) > 0 + assert "Calculator.add" in result + + def test_tests_suffix_pattern(self, tmp_path: Path): + """Test that ClassNameTests matches ClassName.""" + src_file = tmp_path / "Calculator.java" + src_file.write_text(""" +public class Calculator { + public int add(int a, int b) { return a + b; } +} +""") + + test_dir = tmp_path / "test" + test_dir.mkdir() + test_file = test_dir / "CalculatorTests.java" + test_file.write_text(""" +import org.junit.jupiter.api.Test; +public class CalculatorTests { + @Test + public void testAdd() { } +} +""") + + source_functions = discover_functions_from_source(src_file.read_text(), src_file) + result = discover_tests(test_dir, source_functions) + + # CalculatorTests should match Calculator class + assert len(result) > 0 + assert "Calculator.add" in result