Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ cython_debug/
.aider*
/js/common/node_modules/
*.xml
# Allow pom.xml in test fixtures for Maven project detection
!tests/test_languages/fixtures/**/pom.xml
*.pem

# Ruff cache
Expand Down
108 changes: 101 additions & 7 deletions codeflash/languages/java/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,12 @@ def discover_tests(
function_map[func.name] = func
function_map[func.qualified_name] = func

# Find all test files
test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java"))
# Find all test files (various naming conventions)
test_files = (
list(test_root.rglob("*Test.java"))
+ list(test_root.rglob("*Tests.java"))
+ list(test_root.rglob("Test*.java"))
)

# Result map
result: dict[str, list[TestInfo]] = defaultdict(list)
Expand Down Expand Up @@ -134,11 +138,13 @@ def _match_test_to_functions(
matched.append(qualified)

# Strategy 3: Test class naming convention
# e.g., CalculatorTest tests Calculator
# e.g., CalculatorTest tests Calculator, TestCalculator tests Calculator
if test_method.class_name:
# Remove "Test" suffix or prefix
# Remove "Test/Tests" suffix or "Test" prefix
source_class_name = test_method.class_name
if source_class_name.endswith("Test"):
if source_class_name.endswith("Tests"):
source_class_name = source_class_name[:-5]
elif source_class_name.endswith("Test"):
source_class_name = source_class_name[:-4]
elif source_class_name.startswith("Test"):
source_class_name = source_class_name[4:]
Expand All @@ -149,9 +155,93 @@ def _match_test_to_functions(
if func_info.qualified_name not in matched:
matched.append(func_info.qualified_name)

# Strategy 4: Import-based matching
# If the test file imports a class containing the target function, consider it a match
# This handles cases like TestQueryBlob importing Buffer and calling Buffer methods
imported_classes = _extract_imports(tree.root_node, source_bytes, analyzer)

for func_name, func_info in function_map.items():
if func_info.qualified_name in matched:
continue

# Check if the function's class is imported
if func_info.class_name and func_info.class_name in imported_classes:
matched.append(func_info.qualified_name)

return matched


def _extract_imports(
node,
source_bytes: bytes,
analyzer: JavaAnalyzer,
) -> set[str]:
"""Extract imported class names from a Java file.

Args:
node: Tree-sitter root node.
source_bytes: Source code as bytes.
analyzer: JavaAnalyzer instance.

Returns:
Set of imported class names (simple names, not fully qualified).

"""
imports: set[str] = set()

def visit(n):
if n.type == "import_declaration":
import_text = analyzer.get_node_text(n, source_bytes)

# Check if it's a wildcard import - skip these as we can't know specific classes
if import_text.rstrip(";").endswith(".*"):
# For static wildcard imports like "import static com.example.Utils.*"
# we CAN extract the class name (Utils)
if "import static" in import_text:
# Extract class from "import static com.example.Utils.*"
# Remove "import static " prefix and ".*;" suffix
path = import_text.replace("import static ", "").rstrip(";").rstrip(".*")
if "." in path:
class_name = path.rsplit(".", 1)[-1]
if class_name and class_name[0].isupper(): # Ensure it's a class name
imports.add(class_name)
# For regular wildcards like "import com.example.*", skip entirely
return

# Check if it's a static import of a specific method/field
if "import static" in import_text:
# "import static com.example.Utils.format;"
# We want to extract "Utils" (the class), not "format" (the method)
path = import_text.replace("import static ", "").rstrip(";")
parts = path.rsplit(".", 2) # Split into [package..., Class, member]
if len(parts) >= 2:
# The second-to-last part is the class name
class_name = parts[-2]
if class_name and class_name[0].isupper(): # Ensure it's a class name
imports.add(class_name)
return

# Regular import: extract class name from scoped_identifier
for child in n.children:
if child.type == "scoped_identifier" or child.type == "identifier":
import_path = analyzer.get_node_text(child, source_bytes)
# Extract just the class name (last part)
# e.g., "com.example.Buffer" -> "Buffer"
if "." in import_path:
class_name = import_path.rsplit(".", 1)[-1]
else:
class_name = import_path
# Skip if it looks like a package name (lowercase)
if class_name and class_name[0].isupper():
imports.add(class_name)

for child in n.children:
visit(child)

visit(node)
return imports


def _find_method_calls_in_range(
node,
source_bytes: bytes,
Expand Down Expand Up @@ -260,8 +350,12 @@ def discover_all_tests(
analyzer = analyzer or get_java_analyzer()
all_tests: list[FunctionInfo] = []

# Find all test files
test_files = list(test_root.rglob("*Test.java")) + list(test_root.rglob("Test*.java"))
# Find all test files (various naming conventions)
test_files = (
list(test_root.rglob("*Test.java"))
+ list(test_root.rglob("*Tests.java"))
+ list(test_root.rglob("Test*.java"))
)

for test_file in test_files:
try:
Expand Down
52 changes: 52 additions & 0 deletions tests/test_languages/fixtures/java_maven/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<groupId>com.example</groupId>
<artifactId>codeflash-test-fixture</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>

<properties>
<maven.compiler.source>11</maven.compiler.source>
<maven.compiler.target>11</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<junit.jupiter.version>5.10.0</junit.jupiter.version>
</properties>

<dependencies>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>${junit.jupiter.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<version>${junit.jupiter.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>11</source>
<target>11</target>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.1.2</version>
</plugin>
</plugins>
</build>
</project>
Loading
Loading