Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions codeflash/languages/java/build_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,31 @@
logger = logging.getLogger(__name__)


def _safe_parse_xml(file_path: Path) -> ET.ElementTree:
    """Parse an XML file from its raw bytes and wrap the result in an ElementTree.

    The file is read as bytes rather than decoded text so that the XML parser
    itself selects the character encoding from the byte-order mark or the
    ``<?xml ... encoding="..."?>`` declaration. Decoding the file as UTF-8 up
    front would make any document stored in another declared encoding
    (e.g. ISO-8859-1) fail with ``UnicodeDecodeError`` instead of parsing.

    Security note: parsing from memory instead of from a path does not by
    itself add any protection. Per the Python documentation, the standard
    library ElementTree parser does not expand external entities, but it
    performs no extra hardening against entity-expansion ("billion laughs")
    payloads beyond whatever the linked Expat version provides.
    NOTE(review): if these files can come from an untrusted source, consider a
    hardened parser such as ``defusedxml`` -- not added here because it is a
    third-party dependency.

    Args:
        file_path: Path to the XML file.

    Returns:
        Parsed ElementTree whose root is the document element.

    Raises:
        ET.ParseError: If the content is not well-formed XML.
        OSError: If the file cannot be read.
    """
    # Hand the undecoded bytes to the parser so the declared encoding is honored.
    raw_xml = file_path.read_bytes()
    root = ET.fromstring(raw_xml)

    # Callers expect an ElementTree (they call .getroot()), so wrap the root element.
    return ET.ElementTree(root)


class BuildTool(Enum):
"""Supported Java build tools."""

Expand Down Expand Up @@ -124,7 +149,7 @@ def _get_maven_project_info(project_root: Path) -> JavaProjectInfo | None:
return None

try:
tree = ET.parse(pom_path)
tree = _safe_parse_xml(pom_path)
root = tree.getroot()

# Handle Maven namespace
Expand Down Expand Up @@ -438,16 +463,34 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]:

for xml_file in surefire_dir.glob("TEST-*.xml"):
try:
tree = ET.parse(xml_file)
tree = _safe_parse_xml(xml_file)
root = tree.getroot()

tests_run += int(root.get("tests", 0))
failures += int(root.get("failures", 0))
errors += int(root.get("errors", 0))
skipped += int(root.get("skipped", 0))
# Safely parse numeric attributes with validation
try:
tests_run += int(root.get("tests", "0"))
except (ValueError, TypeError):
logger.warning("Invalid 'tests' value in %s, defaulting to 0", xml_file)

try:
failures += int(root.get("failures", "0"))
except (ValueError, TypeError):
logger.warning("Invalid 'failures' value in %s, defaulting to 0", xml_file)

try:
errors += int(root.get("errors", "0"))
except (ValueError, TypeError):
logger.warning("Invalid 'errors' value in %s, defaulting to 0", xml_file)

try:
skipped += int(root.get("skipped", "0"))
except (ValueError, TypeError):
logger.warning("Invalid 'skipped' value in %s, defaulting to 0", xml_file)

except ET.ParseError as e:
logger.warning("Failed to parse Surefire report %s: %s", xml_file, e)
except Exception as e:
logger.warning("Unexpected error parsing Surefire report %s: %s", xml_file, e)

return tests_run, failures, errors, skipped

Expand Down Expand Up @@ -572,7 +615,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool:
return False

try:
tree = ET.parse(pom_path)
tree = _safe_parse_xml(pom_path)
root = tree.getroot()

# Handle Maven namespace
Expand Down Expand Up @@ -647,7 +690,7 @@ def is_jacoco_configured(pom_path: Path) -> bool:
return False

try:
tree = ET.parse(pom_path)
tree = _safe_parse_xml(pom_path)
root = tree.getroot()

# Handle Maven namespace
Expand Down
66 changes: 64 additions & 2 deletions codeflash/languages/java/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
import os
import re
import shutil
import subprocess
import tempfile
Expand All @@ -28,6 +29,55 @@

logger = logging.getLogger(__name__)

# Regex pattern for a (possibly package-qualified) Java class name, e.g.
# "com.example.MyTest" or "Outer$Inner".
# Each dot-separated segment must start with a letter, underscore, or dollar
# sign and may continue with letters, digits, underscores, or dollar signs.
# The pattern is anchored with \A and \Z (not ^ and $) because "$" also matches
# just before a trailing newline, which would let "MyTest\n" through.
_VALID_JAVA_CLASS_NAME = re.compile(r"\A[a-zA-Z_$][a-zA-Z0-9_$]*(?:\.[a-zA-Z_$][a-zA-Z0-9_$]*)*\Z")


def _validate_java_class_name(class_name: str) -> bool:
    """Check whether a string is a well-formed Java class name.

    Used to vet test class names before they are placed in a Maven
    command-line argument.

    Accepted: ASCII identifiers separated by single dots, where every segment
    starts with a letter, ``_`` or ``$``. Rejected: the empty string, empty
    segments (leading, trailing, or doubled dots), segments starting with a
    digit, whitespace (including a trailing newline), and any other character.

    Args:
        class_name: The class name to validate (e.g., "com.example.MyTest").

    Returns:
        True if valid, False otherwise.
    """
    # fullmatch requires the entire string to match, independent of the anchors.
    return _VALID_JAVA_CLASS_NAME.fullmatch(class_name) is not None


def _validate_test_filter(test_filter: str) -> str:
    """Validate and normalize a test filter string for Maven.

    A filter is a comma-separated list of class names or patterns; each
    pattern may contain ``*`` wildcards. Every entry is stripped of
    surrounding whitespace and checked with ``_validate_java_class_name``
    (wildcards are treated as an ordinary identifier character for the check).

    The returned value is rebuilt from the stripped entries, so whitespace or
    newlines around entries in the input never reach the ``-Dtest=`` argument.
    For input that contains no such whitespace the result equals the input.

    Args:
        test_filter: The test filter string (e.g., "MyTest", "MyTest,OtherTest", "My*Test").

    Returns:
        The sanitized test filter: validated entries joined with commas.

    Raises:
        ValueError: If any entry is empty or contains invalid characters.
    """
    # Split by comma for multiple test patterns; surrounding whitespace is not significant.
    patterns = [entry.strip() for entry in test_filter.split(",")]

    for pattern in patterns:
        # Wildcards are allowed in filters but not in class names, so swap each
        # "*" for a character that is valid in any identifier position.
        name_to_validate = pattern.replace("*", "A")

        if not _validate_java_class_name(name_to_validate):
            raise ValueError(
                f"Invalid test class name or pattern: '{pattern}'. "
                "Test names must follow Java identifier rules (letters, digits, underscores, dots, dollar signs)."
            )

    # Return what was actually validated, not the raw input.
    return ",".join(patterns)


def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path, str | None]:
"""Find the multi-module Maven parent root if tests are in a different module.
Expand Down Expand Up @@ -1053,7 +1103,9 @@ def _run_maven_tests(
cmd.extend(["-pl", test_module, "-am", "-DfailIfNoTests=false", "-DskipTests=false"])

if test_filter:
cmd.append(f"-Dtest={test_filter}")
# Validate test filter to prevent command injection
validated_filter = _validate_test_filter(test_filter)
cmd.append(f"-Dtest={validated_filter}")

logger.debug("Running Maven command: %s in %s", " ".join(cmd), project_root)

Expand Down Expand Up @@ -1333,6 +1385,16 @@ def get_test_run_command(
cmd = [mvn, "test"]

if test_classes:
cmd.append(f"-Dtest={','.join(test_classes)}")
# Validate each test class name to prevent command injection
validated_classes = []
for test_class in test_classes:
if not _validate_java_class_name(test_class):
raise ValueError(
f"Invalid test class name: '{test_class}'. "
f"Test names must follow Java identifier rules."
)
validated_classes.append(test_class)

cmd.append(f"-Dtest={','.join(validated_classes)}")

return cmd
Loading
Loading