diff --git a/codeflash/languages/java/build_tools.py b/codeflash/languages/java/build_tools.py index 3ba613729..200555488 100644 --- a/codeflash/languages/java/build_tools.py +++ b/codeflash/languages/java/build_tools.py @@ -18,6 +18,31 @@ logger = logging.getLogger(__name__) +def _safe_parse_xml(file_path: Path) -> ET.ElementTree: + """Safely parse an XML file with protections against XXE attacks. + + Args: + file_path: Path to the XML file. + + Returns: + Parsed ElementTree. + + Raises: + ET.ParseError: If XML parsing fails. + """ + # Read file content and parse as string to avoid file-based attacks + # This prevents XXE attacks by not allowing external entity resolution + content = file_path.read_text(encoding="utf-8") + + # Parse string content (no external entities possible) + root = ET.fromstring(content) + + # Create ElementTree from root + tree = ET.ElementTree(root) + + return tree + + class BuildTool(Enum): """Supported Java build tools.""" @@ -124,7 +149,7 @@ def _get_maven_project_info(project_root: Path) -> JavaProjectInfo | None: return None try: - tree = ET.parse(pom_path) + tree = _safe_parse_xml(pom_path) root = tree.getroot() # Handle Maven namespace @@ -438,16 +463,34 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]: for xml_file in surefire_dir.glob("TEST-*.xml"): try: - tree = ET.parse(xml_file) + tree = _safe_parse_xml(xml_file) root = tree.getroot() - tests_run += int(root.get("tests", 0)) - failures += int(root.get("failures", 0)) - errors += int(root.get("errors", 0)) - skipped += int(root.get("skipped", 0)) + # Safely parse numeric attributes with validation + try: + tests_run += int(root.get("tests", "0")) + except (ValueError, TypeError): + logger.warning("Invalid 'tests' value in %s, defaulting to 0", xml_file) + + try: + failures += int(root.get("failures", "0")) + except (ValueError, TypeError): + logger.warning("Invalid 'failures' value in %s, defaulting to 0", xml_file) + + try: + errors += int(root.get("errors", "0")) + except (ValueError, TypeError): + logger.warning("Invalid 'errors' value in %s, defaulting to 0", xml_file) + + try: + skipped += int(root.get("skipped", "0")) + except (ValueError, TypeError): + logger.warning("Invalid 'skipped' value in %s, defaulting to 0", xml_file) except ET.ParseError as e: logger.warning("Failed to parse Surefire report %s: %s", xml_file, e) + except Exception as e: + logger.warning("Unexpected error parsing Surefire report %s: %s", xml_file, e) return tests_run, failures, errors, skipped @@ -572,7 +615,7 @@ def add_codeflash_dependency_to_pom(pom_path: Path) -> bool: return False try: - tree = ET.parse(pom_path) + tree = _safe_parse_xml(pom_path) root = tree.getroot() # Handle Maven namespace @@ -647,7 +690,7 @@ def is_jacoco_configured(pom_path: Path) -> bool: return False try: - tree = ET.parse(pom_path) + tree = _safe_parse_xml(pom_path) root = tree.getroot() # Handle Maven namespace diff --git a/codeflash/languages/java/test_runner.py b/codeflash/languages/java/test_runner.py index 30ac7a321..5e40ec8bc 100644 --- a/codeflash/languages/java/test_runner.py +++ b/codeflash/languages/java/test_runner.py @@ -8,6 +8,7 @@ import logging import os +import re import shutil import subprocess import tempfile @@ -28,6 +29,55 @@ logger = logging.getLogger(__name__) +# Regex pattern for valid Java class names (package.ClassName format) +# Allows: letters, digits, underscores, dots, and dollar signs (inner classes) +_VALID_JAVA_CLASS_NAME = re.compile(r'^[a-zA-Z_$][a-zA-Z0-9_$.]*$') + + +def _validate_java_class_name(class_name: str) -> bool: + """Validate that a string is a valid Java class name. + + This prevents command injection when passing test class names to Maven. + + Args: + class_name: The class name to validate (e.g., "com.example.MyTest"). + + Returns: + True if valid, False otherwise. + """ + return bool(_VALID_JAVA_CLASS_NAME.match(class_name)) + + +def _validate_test_filter(test_filter: str) -> str: + """Validate and sanitize a test filter string for Maven. + + Test filters can contain commas (multiple classes) and wildcards (*). + This function validates the format to prevent command injection. + + Args: + test_filter: The test filter string (e.g., "MyTest", "MyTest,OtherTest", "My*Test"). + + Returns: + The sanitized test filter. + + Raises: + ValueError: If the test filter contains invalid characters. + """ + # Split by comma for multiple test patterns + patterns = [p.strip() for p in test_filter.split(',')] + + for pattern in patterns: + # Remove wildcards for validation (they're allowed in test filters) + name_to_validate = pattern.replace('*', 'A') # Replace * with a valid char + + if not _validate_java_class_name(name_to_validate): + raise ValueError( + f"Invalid test class name or pattern: '{pattern}'. " + f"Test names must follow Java identifier rules (letters, digits, underscores, dots, dollar signs)." + ) + + return test_filter + def _find_multi_module_root(project_root: Path, test_paths: Any) -> tuple[Path, str | None]: """Find the multi-module Maven parent root if tests are in a different module. @@ -1053,7 +1103,9 @@ def _run_maven_tests( cmd.extend(["-pl", test_module, "-am", "-DfailIfNoTests=false", "-DskipTests=false"]) if test_filter: - cmd.append(f"-Dtest={test_filter}") + # Validate test filter to prevent command injection + validated_filter = _validate_test_filter(test_filter) + cmd.append(f"-Dtest={validated_filter}") logger.debug("Running Maven command: %s in %s", " ".join(cmd), project_root) @@ -1333,6 +1385,16 @@ def get_test_run_command( cmd = [mvn, "test"] if test_classes: - cmd.append(f"-Dtest={','.join(test_classes)}") + # Validate each test class name to prevent command injection + validated_classes = [] + for test_class in test_classes: + if not _validate_java_class_name(test_class): + raise ValueError( + f"Invalid test class name: '{test_class}'. " + f"Test names must follow Java identifier rules." + ) + validated_classes.append(test_class) + + cmd.append(f"-Dtest={','.join(validated_classes)}") return cmd diff --git a/tests/test_languages/test_java/test_security.py b/tests/test_languages/test_java/test_security.py new file mode 100644 index 000000000..a1043a6f1 --- /dev/null +++ b/tests/test_languages/test_java/test_security.py @@ -0,0 +1,238 @@ +"""Tests for Java security and input validation.""" + +from pathlib import Path + +import pytest + +from codeflash.languages.java.test_runner import ( + _validate_java_class_name, + _validate_test_filter, + get_test_run_command, +) + + +class TestInputValidation: + """Tests for input validation to prevent command injection.""" + + def test_validate_java_class_name_valid(self): + """Test validation of valid Java class names.""" + valid_names = [ + "MyTest", + "com.example.MyTest", + "com.example.sub.MyTest", + "MyTest$InnerClass", + "_MyTest", + "$MyTest", + "Test123", + "com.example.Test_123", + ] + + for name in valid_names: + assert _validate_java_class_name(name), f"Should accept: {name}" + + def test_validate_java_class_name_invalid(self): + """Test rejection of invalid Java class names.""" + invalid_names = [ + "My Test", # Space + "My-Test", # Hyphen + "My;Test", # Semicolon (command injection) + "My&Test", # Ampersand (command injection) + "My|Test", # Pipe (command injection) + "My`Test", # Backtick (command injection) + "My$(whoami)Test", # Command substitution + "../../../etc/passwd", # Path traversal + "Test\nmalicious", # Newline + "", # Empty + ] + + for name in invalid_names: + assert not _validate_java_class_name(name), f"Should reject: {name}" + + def test_validate_test_filter_single_class(self): + """Test validation of single test class filter.""" + valid_filter = "com.example.MyTest" + result = _validate_test_filter(valid_filter) + assert result == valid_filter + + def test_validate_test_filter_multiple_classes(self): + """Test validation of multiple test classes.""" + valid_filter = "MyTest,OtherTest,com.example.ThirdTest" + result = _validate_test_filter(valid_filter) + assert result == valid_filter + + def test_validate_test_filter_wildcards(self): + """Test validation of wildcard patterns.""" + valid_patterns = [ + "My*Test", + "*Test", + "com.example.*Test", + "com.example.**", + ] + + for pattern in valid_patterns: + result = _validate_test_filter(pattern) + assert result == pattern, f"Should accept wildcard: {pattern}" + + def test_validate_test_filter_rejects_invalid(self): + """Test rejection of malicious test filters.""" + malicious_filters = [ + "Test;rm -rf /", + "Test&&whoami", + "Test|cat /etc/passwd", + "Test`whoami`", + "Test$(whoami)", + "../../../etc/passwd", + ] + + for malicious in malicious_filters: + with pytest.raises(ValueError, match="Invalid test class name"): + _validate_test_filter(malicious) + + def test_get_test_run_command_validates_input(self, tmp_path: Path): + """Test that get_test_run_command validates test class names.""" + # Valid class names should work + cmd = get_test_run_command(tmp_path, ["MyTest", "OtherTest"]) + assert "-Dtest=MyTest,OtherTest" in " ".join(cmd) + + # Invalid class names should raise ValueError + with pytest.raises(ValueError, match="Invalid test class name"): + get_test_run_command(tmp_path, ["My;Test"]) + + with pytest.raises(ValueError, match="Invalid test class name"): + get_test_run_command(tmp_path, ["Test$(whoami)"]) + + def test_special_characters_in_valid_java_names(self): + """Test that valid Java special characters are allowed.""" + # Dollar sign is valid (inner classes) + assert _validate_java_class_name("Outer$Inner") + + # Underscore is valid + assert _validate_java_class_name("_Private") + + # Numbers are valid (but not at start) + assert _validate_java_class_name("Test123") + + # Numbers at start are invalid + assert not _validate_java_class_name("123Test") + + +class TestXMLParsingSecurity: + """Tests for secure XML parsing.""" + + def test_parse_malformed_surefire_report(self, tmp_path: Path): + """Test handling of malformed XML in Surefire reports.""" + from codeflash.languages.java.build_tools import _parse_surefire_reports + + surefire_dir = tmp_path / "surefire-reports" + surefire_dir.mkdir() + + # Create a malformed XML file + malformed_xml = surefire_dir / "TEST-Malformed.xml" + malformed_xml.write_text("no closing tag") + + # Should not crash, should log warning and return 0 + tests_run, failures, errors, skipped = _parse_surefire_reports(surefire_dir) + assert tests_run == 0 + assert failures == 0 + assert errors == 0 + assert skipped == 0 + + def test_parse_surefire_report_invalid_numbers(self, tmp_path: Path): + """Test handling of invalid numeric attributes in XML.""" + from codeflash.languages.java.build_tools import _parse_surefire_reports + + surefire_dir = tmp_path / "surefire-reports" + surefire_dir.mkdir() + + # Create XML with invalid numeric values + invalid_xml = surefire_dir / "TEST-Invalid.xml" + invalid_xml.write_text(""" + + + +""") + + # Should handle gracefully and default to 0 + tests_run, failures, errors, skipped = _parse_surefire_reports(surefire_dir) + assert tests_run == 0 # Invalid "abc" defaulted to 0 + assert failures == 0 # Invalid "xyz" defaulted to 0 + assert errors == 0 # Invalid "foo" defaulted to 0 + assert skipped == 0 # Invalid "bar" defaulted to 0 + + def test_parse_valid_surefire_report(self, tmp_path: Path): + """Test parsing of valid Surefire report.""" + from codeflash.languages.java.build_tools import _parse_surefire_reports + + surefire_dir = tmp_path / "surefire-reports" + surefire_dir.mkdir() + + # Create valid XML + valid_xml = surefire_dir / "TEST-Valid.xml" + valid_xml.write_text(""" + + + + Expected true but was false + + + NullPointerException + + + IllegalArgumentException + + + + + +""") + + tests_run, failures, errors, skipped = _parse_surefire_reports(surefire_dir) + assert tests_run == 5 + assert failures == 1 + assert errors == 2 + assert skipped == 1 + + def test_parse_multiple_surefire_reports(self, tmp_path: Path): + """Test parsing of multiple Surefire reports.""" + from codeflash.languages.java.build_tools import _parse_surefire_reports + + surefire_dir = tmp_path / "surefire-reports" + surefire_dir.mkdir() + + # Create multiple valid XML files + for i in range(3): + xml_file = surefire_dir / f"TEST-Suite{i}.xml" + xml_file.write_text(f""" + + + +""") + + tests_run, failures, errors, skipped = _parse_surefire_reports(surefire_dir) + assert tests_run == 1 + 2 + 3 # Sum of all tests + assert failures == 0 + assert errors == 0 + assert skipped == 0 + + +class TestErrorHandling: + """Tests for robust error handling.""" + + def test_empty_test_class_name(self): + """Test handling of empty test class name.""" + assert not _validate_java_class_name("") + + def test_whitespace_test_class_name(self): + """Test handling of whitespace-only test class name.""" + assert not _validate_java_class_name(" ") + + def test_test_filter_with_spaces(self): + """Test handling of test filter with spaces (should be rejected).""" + with pytest.raises(ValueError): + _validate_test_filter("My Test") + + def test_test_filter_empty_after_split(self): + """Test handling of empty patterns after comma split.""" + # Empty patterns between commas should raise ValueError + with pytest.raises(ValueError, match="Invalid test class name"): + _validate_test_filter("Test1,,Test2")