diff --git a/code_to_optimize/js/code_to_optimize_js/bubble_sort.js b/code_to_optimize/js/code_to_optimize_js/bubble_sort.js index 8f3c9ffca..fe63d82dc 100644 --- a/code_to_optimize/js/code_to_optimize_js/bubble_sort.js +++ b/code_to_optimize/js/code_to_optimize_js/bubble_sort.js @@ -11,14 +11,21 @@ function bubbleSort(arr) { const result = arr.slice(); const n = result.length; - for (let i = 0; i < n; i++) { - for (let j = 0; j < n - 1; j++) { - if (result[j] > result[j + 1]) { - const temp = result[j]; - result[j] = result[j + 1]; - result[j + 1] = temp; + if (n <= 1) return result; + + for (let i = 0; i < n - 1; i++) { + let swapped = false; + const limit = n - i - 1; + for (let j = 0; j < limit; j++) { + const a = result[j]; + const b = result[j + 1]; + if (a > b) { + result[j] = b; + result[j + 1] = a; + swapped = true; } } + if (!swapped) break; } return result; diff --git a/code_to_optimize/js/code_to_optimize_vitest/package-lock.json b/code_to_optimize/js/code_to_optimize_vitest/package-lock.json index ac3d39afd..ef24dc459 100644 --- a/code_to_optimize/js/code_to_optimize_vitest/package-lock.json +++ b/code_to_optimize/js/code_to_optimize_vitest/package-lock.json @@ -15,7 +15,7 @@ } }, "../../../packages/codeflash": { - "version": "0.7.0", + "version": "0.8.0", "dev": true, "hasInstallScript": true, "license": "MIT", diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 95fc5d506..7a9afc96f 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -37,21 +37,6 @@ def is_glob_pattern(path_str: str) -> bool: def normalize_ignore_paths(paths: list[str], base_path: Path | None = None) -> list[Path]: - """Normalize ignore paths, expanding glob patterns and resolving paths. - - Accepts a list of path strings that can be either: - - Literal paths (relative or absolute): e.g., "node_modules", "/absolute/path" - - Glob patterns: e.g., "**/*.test.js", "dist/*", "*.log" - - Args: - paths: List of path strings (literal paths or glob patterns). - base_path: Base path for resolving relative paths and patterns. - If None, uses current working directory. - - Returns: - List of resolved Path objects, deduplicated. - - """ if base_path is None: base_path = Path.cwd() @@ -59,22 +44,25 @@ def normalize_ignore_paths(paths: list[str], base_path: Path | None = None) -> l normalized: set[Path] = set() for path_str in paths: + if not path_str: + continue + + path_str = str(path_str) + if is_glob_pattern(path_str): - # It's a glob pattern - expand it - # Use base_path as the root for glob expansion - pattern_path = base_path / path_str - # glob returns an iterator of matching paths + # pathlib requires relative glob patterns + path_str = path_str.removeprefix("./") + if path_str.startswith("/"): + path_str = path_str.lstrip("/") + for matched_path in base_path.glob(path_str): - if matched_path.exists(): - normalized.add(matched_path.resolve()) + normalized.add(matched_path.resolve()) else: - # It's a literal path path_obj = Path(path_str) if not path_obj.is_absolute(): path_obj = base_path / path_obj if path_obj.exists(): normalized.add(path_obj.resolve()) - # Silently skip non-existent literal paths (e.g., .next, dist before build) return list(normalized) diff --git a/codeflash/code_utils/time_utils.py b/codeflash/code_utils/time_utils.py index e44c279d3..ff04b5037 100644 --- a/codeflash/code_utils/time_utils.py +++ b/codeflash/code_utils/time_utils.py @@ -1,10 +1,5 @@ from __future__ import annotations -import datetime as dt -import re - -import humanize - def humanize_runtime(time_in_ns: int) -> str: runtime_human: str = str(time_in_ns) @@ -14,22 +9,32 @@ def humanize_runtime(time_in_ns: int) -> str: if time_in_ns / 1000 >= 1: time_micro = float(time_in_ns) / 1000 - runtime_human = humanize.precisedelta(dt.timedelta(microseconds=time_micro), minimum_unit="microseconds") - units = re.split(r",|\s", runtime_human)[1] - - if units in {"microseconds", "microsecond"}: + # Direct unit determination and formatting without external library + if time_micro < 1000: runtime_human = f"{time_micro:.3g}" - elif units in {"milliseconds", "millisecond"}: - runtime_human = "%.3g" % (time_micro / 1000) - elif units in {"seconds", "second"}: - runtime_human = "%.3g" % (time_micro / (1000**2)) - elif units in {"minutes", "minute"}: - runtime_human = "%.3g" % (time_micro / (60 * 1000**2)) - elif units in {"hour", "hours"}: # hours - runtime_human = "%.3g" % (time_micro / (3600 * 1000**2)) + units = "microseconds" if time_micro >= 2 else "microsecond" + elif time_micro < 1000000: + time_milli = time_micro / 1000 + runtime_human = f"{time_milli:.3g}" + units = "milliseconds" if time_milli >= 2 else "millisecond" + elif time_micro < 60000000: + time_sec = time_micro / 1000000 + runtime_human = f"{time_sec:.3g}" + units = "seconds" if time_sec >= 2 else "second" + elif time_micro < 3600000000: + time_min = time_micro / 60000000 + runtime_human = f"{time_min:.3g}" + units = "minutes" if time_min >= 2 else "minute" + elif time_micro < 86400000000: + time_hour = time_micro / 3600000000 + runtime_human = f"{time_hour:.3g}" + units = "hours" if time_hour >= 2 else "hour" else: # days - runtime_human = "%.3g" % (time_micro / (24 * 3600 * 1000**2)) + time_day = time_micro / 86400000000 + runtime_human = f"{time_day:.3g}" + units = "days" if time_day >= 2 else "day" + runtime_human_parts = str(runtime_human).split(".") if len(runtime_human_parts[0]) == 1: if runtime_human_parts[0] == "1" and len(runtime_human_parts) > 1: diff --git a/codeflash/github/PrComment.py b/codeflash/github/PrComment.py index ffba759b5..85c24ec57 100644 --- a/codeflash/github/PrComment.py +++ b/codeflash/github/PrComment.py @@ -45,8 +45,8 @@ def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[B } if self.original_async_throughput is not None and self.best_async_throughput is not None: - result["original_async_throughput"] = str(self.original_async_throughput) - result["best_async_throughput"] = str(self.best_async_throughput) + result["original_async_throughput"] = self.original_async_throughput + result["best_async_throughput"] = self.best_async_throughput return result diff --git a/codeflash/languages/javascript/instrument.py b/codeflash/languages/javascript/instrument.py index 938c160aa..dee534044 100644 --- a/codeflash/languages/javascript/instrument.py +++ b/codeflash/languages/javascript/instrument.py @@ -56,6 +56,46 @@ class StandaloneCallMatch: ) +def is_inside_string(code: str, pos: int) -> bool: + """Check if a position in code is inside a string literal. + + Handles single quotes, double quotes, and template literals (backticks). + Properly handles escaped quotes. + + Args: + code: The source code. + pos: The position to check. + + Returns: + True if the position is inside a string literal. + + """ + in_string = False + string_char = None + i = 0 + + while i < pos: + char = code[i] + + if in_string: + # Check for escape sequence + if char == "\\" and i + 1 < len(code): + i += 2 # Skip escaped character + continue + # Check for end of string + if char == string_char: + in_string = False + string_char = None + # Check for start of string + elif char in "\"'`": + in_string = True + string_char = char + + i += 1 + + return in_string + + class StandaloneCallTransformer: """Transforms standalone func(...) calls in JavaScript test code. @@ -82,6 +122,11 @@ def __init__(self, function_to_optimize: FunctionToOptimize, capture_func: str) # Captures: (whitespace)(await )?(object.)*func_name( # We'll filter out expect() and codeflash. cases in the transform loop self._call_pattern = re.compile(rf"(\s*)(await\s+)?((?:\w+\.)*){re.escape(self.func_name)}\s*\(") + # Pattern to match bracket notation: obj['func_name']( or obj["func_name"]( + # Captures: (whitespace)(await )?(obj)['|"]func_name['|"]( + self._bracket_call_pattern = re.compile( + rf"(\s*)(await\s+)?(\w+)\[['\"]({re.escape(self.func_name)})['\"]]\s*\(" + ) def transform(self, code: str) -> str: """Transform all standalone calls in the code.""" @@ -89,7 +134,25 @@ def transform(self, code: str) -> str: pos = 0 while pos < len(code): - match = self._call_pattern.search(code, pos) + # Try both dot notation and bracket notation patterns + dot_match = self._call_pattern.search(code, pos) + bracket_match = self._bracket_call_pattern.search(code, pos) + + # Choose the first match (by position) + match = None + is_bracket_notation = False + if dot_match and bracket_match: + if dot_match.start() <= bracket_match.start(): + match = dot_match + else: + match = bracket_match + is_bracket_notation = True + elif dot_match: + match = dot_match + elif bracket_match: + match = bracket_match + is_bracket_notation = True + if not match: result.append(code[pos:]) break @@ -106,7 +169,11 @@ def transform(self, code: str) -> str: result.append(code[pos:match_start]) # Try to parse the full standalone call - standalone_match = self._parse_standalone_call(code, match) + if is_bracket_notation: + standalone_match = self._parse_bracket_standalone_call(code, match) + else: + standalone_match = self._parse_standalone_call(code, match) + if standalone_match is None: # Couldn't parse, skip this match result.append(code[match_start : match.end()]) @@ -115,7 +182,7 @@ def transform(self, code: str) -> str: # Generate the transformed code self.invocation_counter += 1 - transformed = self._generate_transformed_call(standalone_match) + transformed = self._generate_transformed_call(standalone_match, is_bracket_notation) result.append(transformed) pos = standalone_match.end_pos @@ -123,6 +190,10 @@ def transform(self, code: str) -> str: def _should_skip_match(self, code: str, start: int, match: re.Match) -> bool: """Check if the match should be skipped (inside expect, already transformed, etc.).""" + # Skip if inside a string literal (e.g., test description) + if is_inside_string(code, start): + return True + # Look backwards to check context lookback_start = max(0, start - 200) lookback = code[lookback_start:start] @@ -252,17 +323,24 @@ def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | N in_string = False string_char = None - while pos < len(code) and depth > 0: - char = code[pos] + s = code # local alias for speed + s_len = len(s) + quotes = "\"'`" + + while pos < s_len and depth > 0: + char = s[pos] # Handle string literals - if char in "\"'`" and (pos == 0 or code[pos - 1] != "\\"): - if not in_string: - in_string = True - string_char = char - elif char == string_char: - in_string = False - string_char = None + # Note: preserve original escaping semantics (only checks immediate preceding char) + if char in quotes: + prev_char = s[pos - 1] if pos > 0 else None + if prev_char != "\\": + if not in_string: + in_string = True + string_char = char + elif char == string_char: + in_string = False + string_char = None elif not in_string: if char == "(": depth += 1 @@ -274,19 +352,64 @@ def _find_balanced_parens(self, code: str, open_paren_pos: int) -> tuple[str | N if depth != 0: return None, -1 - return code[open_paren_pos + 1 : pos - 1], pos + # slice once + return s[open_paren_pos + 1 : pos - 1], pos + + def _parse_bracket_standalone_call(self, code: str, match: re.Match) -> StandaloneCallMatch | None: + """Parse a complete standalone obj['func'](...) call with bracket notation.""" + leading_ws = match.group(1) + prefix = match.group(2) or "" # "await " or "" + obj_name = match.group(3) # The object name before bracket + # match.group(4) is the function name inside brackets + + # Find the opening paren position + match_text = match.group(0) + paren_offset = match_text.rfind("(") + open_paren_pos = match.start() + paren_offset + + # Find the arguments (content inside parens) + func_args, close_pos = self._find_balanced_parens(code, open_paren_pos) + if func_args is None: + return None + + # Check for trailing semicolon + end_pos = close_pos + # Skip whitespace + s = code + s_len = len(s) + while end_pos < s_len and s[end_pos] in " \t": + end_pos += 1 + + has_trailing_semicolon = end_pos < s_len and s[end_pos] == ";" + if has_trailing_semicolon: + end_pos += 1 + + return StandaloneCallMatch( + start_pos=match.start(), + end_pos=end_pos, + leading_whitespace=leading_ws, + func_args=func_args, + prefix=prefix, + object_prefix=f"{obj_name}.", # Use dot notation format for consistency + has_trailing_semicolon=has_trailing_semicolon, + ) - def _generate_transformed_call(self, match: StandaloneCallMatch) -> str: + def _generate_transformed_call(self, match: StandaloneCallMatch, is_bracket_notation: bool = False) -> str: """Generate the transformed code for a standalone call.""" line_id = str(self.invocation_counter) args_str = match.func_args.strip() semicolon = ";" if match.has_trailing_semicolon else "" - # Handle method calls on objects (e.g., calc.fibonacci, this.method) + # Handle method calls on objects (e.g., calc.fibonacci, this.method, instance['method']) if match.object_prefix: # Remove trailing dot from object prefix for the bind call obj = match.object_prefix.rstrip(".") - full_method = f"{obj}.{self.func_name}" + + # For bracket notation, use bracket access syntax for the bind + if is_bracket_notation: + full_method = f"{obj}['{self.func_name}']" + else: + full_method = f"{obj}.{self.func_name}" if args_str: return ( @@ -370,6 +493,12 @@ def transform(self, code: str) -> str: result.append(code[pos:]) break + # Skip if inside a string literal (e.g., test description) + if is_inside_string(code, match.start()): + result.append(code[pos : match.end()]) + pos = match.end() + continue + # Add everything before the match result.append(code[pos : match.start()]) @@ -1071,3 +1200,173 @@ def instrument_generated_js_test( mode=mode, remove_assertions=True, ) + + +def fix_imports_inside_test_blocks(test_code: str) -> str: + """Fix import statements that appear inside test/it blocks. + + JavaScript/TypeScript `import` statements must be at the top level of a module. + The AI sometimes generates imports inside test functions, which is invalid syntax. + + This function detects such patterns and converts them to dynamic require() calls + which are valid inside functions. + + Args: + test_code: The generated test code. + + Returns: + Fixed test code with imports converted to require() inside functions. + + """ + if not test_code or not test_code.strip(): + return test_code + + # Pattern to match import statements inside functions + # This captures imports that appear after function/test block openings + # We look for lines that: + # 1. Start with whitespace (indicating they're inside a block) + # 2. Have an import statement + + lines = test_code.split("\n") + result_lines = [] + brace_depth = 0 + in_test_block = False + + for line in lines: + stripped = line.strip() + + # Track brace depth to know if we're inside a block + # Count braces, but ignore braces in strings (simplified check) + for char in stripped: + if char == "{": + brace_depth += 1 + elif char == "}": + brace_depth -= 1 + + # Check if we're entering a test/it/describe block + if re.match(r"^(test|it|describe|beforeEach|afterEach|beforeAll|afterAll)\s*\(", stripped): + in_test_block = True + + # Check for import statement inside a block (brace_depth > 0 means we're inside a function/block) + if brace_depth > 0 and stripped.startswith("import "): + # Convert ESM import to require + # Pattern: import { name } from 'module' -> const { name } = require('module') + # Pattern: import name from 'module' -> const name = require('module') + + named_import = re.match(r"import\s+\{([^}]+)\}\s+from\s+['\"]([^'\"]+)['\"]", stripped) + default_import = re.match(r"import\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]", stripped) + namespace_import = re.match(r"import\s+\*\s+as\s+(\w+)\s+from\s+['\"]([^'\"]+)['\"]", stripped) + + leading_whitespace = line[: len(line) - len(line.lstrip())] + + if named_import: + names = named_import.group(1) + module = named_import.group(2) + new_line = f"{leading_whitespace}const {{{names}}} = require('{module}');" + result_lines.append(new_line) + logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}") + continue + if default_import: + name = default_import.group(1) + module = default_import.group(2) + new_line = f"{leading_whitespace}const {name} = require('{module}');" + result_lines.append(new_line) + logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}") + continue + if namespace_import: + name = namespace_import.group(1) + module = namespace_import.group(2) + new_line = f"{leading_whitespace}const {name} = require('{module}');" + result_lines.append(new_line) + logger.debug(f"Fixed import inside block: {stripped} -> {new_line.strip()}") + continue + + result_lines.append(line) + + return "\n".join(result_lines) + + +def fix_jest_mock_paths(test_code: str, test_file_path: Path, source_file_path: Path, tests_root: Path) -> str: + """Fix relative paths in jest.mock() calls to be correct from the test file's location. + + The AI sometimes generates jest.mock() calls with paths relative to the source file + instead of the test file. For example: + - Source at `src/queue/queue.ts` imports `../environment` (-> src/environment) + - Test at `tests/test.test.ts` generates `jest.mock('../environment')` (-> ./environment, wrong!) + - Should generate `jest.mock('../src/environment')` + + This function detects relative mock paths and adjusts them based on the test file's + location relative to the source file's directory. + + Args: + test_code: The generated test code. + test_file_path: Path to the test file being generated. + source_file_path: Path to the source file being tested. + tests_root: Root directory of the tests. + + Returns: + Fixed test code with corrected mock paths. + + """ + if not test_code or not test_code.strip(): + return test_code + + import os + + # Get the directory containing the source file and the test file + source_dir = source_file_path.resolve().parent + test_dir = test_file_path.resolve().parent + project_root = tests_root.resolve().parent if tests_root.name == "tests" else tests_root.resolve() + + # Pattern to match jest.mock() or jest.doMock() with relative paths + mock_pattern = re.compile(r"(jest\.(?:mock|doMock)\s*\(\s*['\"])(\.\./[^'\"]+|\.\/[^'\"]+)(['\"])") + + def fix_mock_path(match: re.Match[str]) -> str: + original = match.group(0) + prefix = match.group(1) + rel_path = match.group(2) + suffix = match.group(3) + + # Resolve the path as if it were relative to the source file's directory + # (which is how the AI often generates it) + source_relative_resolved = (source_dir / rel_path).resolve() + + # Check if this resolved path exists or if adjusting it would make more sense + # Calculate what the correct relative path from the test file should be + try: + # First, try to find if the path makes sense from the test directory + test_relative_resolved = (test_dir / rel_path).resolve() + + # If the path exists relative to test dir, keep it + if test_relative_resolved.exists() or ( + test_relative_resolved.with_suffix(".ts").exists() + or test_relative_resolved.with_suffix(".js").exists() + or test_relative_resolved.with_suffix(".tsx").exists() + or test_relative_resolved.with_suffix(".jsx").exists() + ): + return original # Keep original, it's valid + + # If path exists relative to source dir, recalculate from test dir + if source_relative_resolved.exists() or ( + source_relative_resolved.with_suffix(".ts").exists() + or source_relative_resolved.with_suffix(".js").exists() + or source_relative_resolved.with_suffix(".tsx").exists() + or source_relative_resolved.with_suffix(".jsx").exists() + ): + # Calculate the correct relative path from test_dir to source_relative_resolved + new_rel_path = os.path.relpath(str(source_relative_resolved), str(test_dir)) + # Ensure it starts with ./ or ../ + if not new_rel_path.startswith("../") and not new_rel_path.startswith("./"): + new_rel_path = f"./{new_rel_path}" + # Use forward slashes + new_rel_path = new_rel_path.replace("\\", "/") + + logger.debug(f"Fixed jest.mock path: {rel_path} -> {new_rel_path}") + return f"{prefix}{new_rel_path}{suffix}" + + except (ValueError, OSError): + pass # Path resolution failed, keep original + + return original # Keep original if we can't fix it + + return mock_pattern.sub(fix_mock_path, test_code) diff --git a/codeflash/languages/javascript/module_system.py b/codeflash/languages/javascript/module_system.py index 66e6fe7e3..89d723c02 100644 --- a/codeflash/languages/javascript/module_system.py +++ b/codeflash/languages/javascript/module_system.py @@ -100,23 +100,40 @@ def detect_module_system(project_root: Path, file_path: Path | None = None) -> s try: content = file_path.read_text() - # Look for ES module syntax + # Look for ES module syntax - these are explicit ESM markers has_import = "import " in content and "from " in content - has_export = "export " in content or "export default" in content or "export {" in content + # Check for export function/class/const/default which are unambiguous ESM syntax + has_esm_export = ( + "export function " in content + or "export class " in content + or "export const " in content + or "export let " in content + or "export default " in content + or "export async function " in content + ) + has_export_block = "export {" in content # Look for CommonJS syntax has_require = "require(" in content has_module_exports = "module.exports" in content or "exports." in content - # Determine based on what we found - if (has_import or has_export) and not (has_require or has_module_exports): - logger.debug("Detected ES Module from import/export statements") + # Prioritize ESM when explicit ESM export syntax is found + # This handles hybrid files that have both `export function` and `module.exports` + # The ESM syntax is more explicit and should take precedence + if has_esm_export or has_import: + logger.debug("Detected ES Module from explicit export/import statements") return ModuleSystem.ES_MODULE - if (has_require or has_module_exports) and not (has_import or has_export): + # Pure CommonJS + if (has_require or has_module_exports) and not has_export_block: logger.debug("Detected CommonJS from require/module.exports") return ModuleSystem.COMMONJS + # Export block without other ESM markers - still ESM + if has_export_block: + logger.debug("Detected ES Module from export block") + return ModuleSystem.ES_MODULE + except Exception as e: logger.warning("Failed to analyze file %s: %s", file_path, e) diff --git a/codeflash/languages/javascript/parse.py b/codeflash/languages/javascript/parse.py index 1bfda8bca..a5e7ae8c6 100644 --- a/codeflash/languages/javascript/parse.py +++ b/codeflash/languages/javascript/parse.py @@ -198,14 +198,20 @@ def parse_jest_test_xml( # Extract console output from suite-level system-out (Jest specific) suite_stdout = _extract_jest_console_output(suite._elem) # noqa: SLF001 - # Fallback: use subprocess stdout if XML system-out is empty - if not suite_stdout and global_stdout: - suite_stdout = global_stdout + # Combine suite stdout with global stdout to ensure we capture all timing markers + # Jest-junit may not capture all console.log output in the XML, so we also need + # to check the subprocess stdout directly for timing markers + combined_stdout = suite_stdout + if global_stdout: + if combined_stdout: + combined_stdout = combined_stdout + "\n" + global_stdout + else: + combined_stdout = global_stdout - # Parse timing markers from the suite's console output - start_matches = list(jest_start_pattern.finditer(suite_stdout)) + # Parse timing markers from the combined console output + start_matches = list(jest_start_pattern.finditer(combined_stdout)) end_matches_dict = {} - for match in jest_end_pattern.finditer(suite_stdout): + for match in jest_end_pattern.finditer(combined_stdout): # Key: (testName, testName2, funcName, loopIndex, lineId) key = match.groups()[:5] end_matches_dict[key] = match @@ -318,7 +324,7 @@ def parse_jest_test_xml( # Infer test type from filename pattern filename = test_file_path.name if "__perf_test_" in filename or "_perf_test_" in filename: - test_type = TestType.GENERATED_PERFORMANCE + test_type = TestType.GENERATED_REGRESSION # Performance tests are still generated regression tests elif "__unit_test_" in filename or "_unit_test_" in filename: test_type = TestType.GENERATED_REGRESSION else: diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index d32cce001..20fe29573 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -104,6 +104,12 @@ def discover_functions( if not criteria.include_async and func.is_async: continue + # Skip non-exported functions (can't be imported in tests) + # Exception: nested functions and methods are allowed if their parent is exported + if not func.is_exported and not func.parent_function: + logger.debug(f"Skipping non-exported function: {func.name}") # noqa: G004 + continue + # Build parents list parents: list[FunctionParent] = [] if func.class_name: @@ -326,8 +332,14 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, else: target_code = "" + imports = analyzer.find_imports(source) + + # Find helper functions called by target (needed before class wrapping to find same-class helpers) + helpers = self._find_helper_functions(function, source, analyzer, imports, module_root) + # For class methods, wrap the method in its class definition # This is necessary because method definition syntax is only valid inside a class body + same_class_helper_names: set[str] = set() if function.is_method and function.parents: class_name = None for parent in function.parents: @@ -336,17 +348,26 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, break if class_name: + # Find same-class helper methods that need to be included inside the class wrapper + same_class_helpers = self._find_same_class_helpers( + class_name, function.function_name, helpers, tree_functions, lines + ) + same_class_helper_names = {h[0] for h in same_class_helpers} # method names + # Find the class definition in the source to get proper indentation, JSDoc, constructor, and fields class_info = self._find_class_definition(source, class_name, analyzer, function.function_name) if class_info: class_jsdoc, class_indent, constructor_code, fields_code = class_info - # Build the class body with fields, constructor, and target method + # Build the class body with fields, constructor, target method, and same-class helpers class_body_parts = [] if fields_code: class_body_parts.append(fields_code) if constructor_code: class_body_parts.append(constructor_code) class_body_parts.append(target_code) + # Add same-class helper methods inside the class body + for _helper_name, helper_source in same_class_helpers: + class_body_parts.append(helper_source) class_body = "\n".join(class_body_parts) # Wrap the method in a class definition with context @@ -357,13 +378,16 @@ def extract_code_context(self, function: FunctionToOptimize, project_root: Path, else: target_code = f"{class_indent}class {class_name} {{\n{class_body}{class_indent}}}\n" else: - # Fallback: wrap with no indentation - target_code = f"class {class_name} {{\n{target_code}}}\n" - - imports = analyzer.find_imports(source) + # Fallback: wrap with no indentation, including same-class helpers + helper_code = "\n".join(h[1] for h in same_class_helpers) + if helper_code: + target_code = f"class {class_name} {{\n{target_code}\n{helper_code}}}\n" + else: + target_code = f"class {class_name} {{\n{target_code}}}\n" - # Find helper functions called by target - helpers = self._find_helper_functions(function, source, analyzer, imports, module_root) + # Filter out same-class helpers from the helpers list (they're already inside the class wrapper) + if same_class_helper_names: + helpers = [h for h in helpers if h.name not in same_class_helper_names] # Extract import statements as strings import_lines = [] @@ -546,6 +570,49 @@ def _extract_class_context( return (constructor_code, fields_code) + def _find_same_class_helpers( + self, + class_name: str, + target_method_name: str, + helpers: list[HelperFunction], + tree_functions: list, + lines: list[str], + ) -> list[tuple[str, str]]: + """Find helper methods that belong to the same class as the target method. + + These helpers need to be included inside the class wrapper rather than + appended outside, because they may use class-specific syntax like 'private'. + + Args: + class_name: Name of the class containing the target method. + target_method_name: Name of the target method (to exclude). + helpers: List of all helper functions found. + tree_functions: List of FunctionNode from tree-sitter analysis. + lines: Source code split into lines. + + Returns: + List of (method_name, source_code) tuples for same-class helpers. + + """ + same_class_helpers: list[tuple[str, str]] = [] + + # Build a set of helper names for quick lookup + helper_names = {h.name for h in helpers} + + # Names to exclude from same-class helpers (target method and constructor) + exclude_names = {target_method_name, "constructor"} + + # Find methods in tree_functions that belong to the same class and are helpers + for func in tree_functions: + if func.class_name == class_name and func.name in helper_names and func.name not in exclude_names: + # Extract source including JSDoc if present + effective_start = func.doc_start_line or func.start_line + helper_lines = lines[effective_start - 1 : func.end_line] + helper_source = "".join(helper_lines) + same_class_helpers.append((func.name, helper_source)) + + return same_class_helpers + def _find_helper_functions( self, function: FunctionToOptimize, diff --git a/codeflash/languages/javascript/test_runner.py b/codeflash/languages/javascript/test_runner.py index c65adfa7b..1d79ad382 100644 --- a/codeflash/languages/javascript/test_runner.py +++ b/codeflash/languages/javascript/test_runner.py @@ -7,6 +7,7 @@ from __future__ import annotations import json +import os import subprocess import time from pathlib import Path @@ -21,6 +22,25 @@ if TYPE_CHECKING: from codeflash.models.models import TestFiles +# Track created config files (jest configs and tsconfigs) for cleanup +_created_config_files: set[Path] = set() + + +def get_created_config_files() -> list[Path]: + """Get list of config files created by codeflash for cleanup. + + Returns: + List of paths to created config files (jest.codeflash.config.js, tsconfig.codeflash.json) + that should be cleaned up after optimization. + + """ + return list(_created_config_files) + + +def clear_created_config_files() -> None: + """Clear the set of tracked config files after cleanup.""" + _created_config_files.clear() + def _detect_bundler_module_resolution(project_root: Path) -> bool: """Detect if the project uses moduleResolution: 'bundler' in tsconfig. @@ -163,6 +183,7 @@ def _create_codeflash_tsconfig(project_root: Path) -> Path: try: codeflash_tsconfig_path.write_text(json.dumps(codeflash_tsconfig, indent=2)) + _created_config_files.add(codeflash_tsconfig_path) logger.debug(f"Created {codeflash_tsconfig_path} with Node moduleResolution") except Exception as e: logger.warning(f"Failed to create codeflash tsconfig: {e}") @@ -170,70 +191,142 @@ def _create_codeflash_tsconfig(project_root: Path) -> Path: return codeflash_tsconfig_path -def _create_codeflash_jest_config(project_root: Path, original_jest_config: Path | None) -> Path | None: - """Create a Jest config that uses the codeflash tsconfig for ts-jest. +def _has_ts_jest_dependency(project_root: Path) -> bool: + """Check if the project has ts-jest as a dependency. + + Args: + project_root: Root of the project. + + Returns: + True if ts-jest is found in dependencies or devDependencies. + + """ + package_json = project_root / "package.json" + if not package_json.exists(): + return False + + try: + content = json.loads(package_json.read_text()) + deps = {**content.get("dependencies", {}), **content.get("devDependencies", {})} + return "ts-jest" in deps + except (json.JSONDecodeError, OSError): + return False + + +def _create_codeflash_jest_config( + project_root: Path, original_jest_config: Path | None, *, for_esm: bool = False +) -> Path | None: + """Create a Jest config that handles ESM packages and TypeScript properly. Args: project_root: Root of the project. original_jest_config: Path to the original Jest config, or None. + for_esm: If True, configure for ESM package transformation. Returns: Path to the codeflash Jest config, or None if creation failed. """ - codeflash_jest_config_path = project_root / "jest.codeflash.config.js" + # For ESM projects (type: module), use .cjs extension since config uses CommonJS require/module.exports + # This prevents "ReferenceError: module is not defined" errors + is_esm = _is_esm_project(project_root) + config_ext = ".cjs" if is_esm else ".js" - # If it already exists, use it + # Create codeflash config in the same directory as the original config + # This ensures relative paths work correctly + if original_jest_config: + codeflash_jest_config_path = original_jest_config.parent / f"jest.codeflash.config{config_ext}" + else: + codeflash_jest_config_path = project_root / f"jest.codeflash.config{config_ext}" + + # If it already exists, use it (check both extensions) if codeflash_jest_config_path.exists(): logger.debug(f"Using existing {codeflash_jest_config_path}") return codeflash_jest_config_path - # Create a wrapper Jest config that uses tsconfig.codeflash.json + # Also check if the alternate extension exists + alt_ext = ".js" if is_esm else ".cjs" + alt_path = codeflash_jest_config_path.with_suffix(alt_ext) + if alt_path.exists(): + logger.debug(f"Using existing {alt_path}") + return alt_path + + # Common ESM-only packages that need to be transformed + # These packages ship only ESM and will cause "Cannot use import statement" errors + esm_packages = [ + "p-queue", + "p-limit", + "p-timeout", + "yocto-queue", + "eventemitter3", + "chalk", + "ora", + "strip-ansi", + "ansi-regex", + "string-width", + "wrap-ansi", + "is-unicode-supported", + "is-interactive", + "log-symbols", + "figures", + ] + esm_pattern = "|".join(esm_packages) + + # Check if ts-jest is available in the project + has_ts_jest = _has_ts_jest_dependency(project_root) + + # Build transform config only if ts-jest is available + if has_ts_jest: + transform_config = """ + // Ensure TypeScript files are transformed using ts-jest + transform: { + '^.+\\\\.(ts|tsx)$': ['ts-jest', { isolatedModules: true }], + // Use ts-jest for JS files in ESM packages too + '^.+\\\\.js$': ['ts-jest', { isolatedModules: true }], + },""" + else: + transform_config = "" + logger.debug("ts-jest not found in project dependencies, skipping transform config") + + # Create a wrapper Jest config if original_jest_config: - # Extend the original config - jest_config_content = f"""// Auto-generated by codeflash for bundler moduleResolution compatibility -const originalConfig = require('./{original_jest_config.name}'); + # Since codeflash config is in the same directory as original, use simple relative path + config_require_path = f"./{original_jest_config.name}" -const tsJestOptions = {{ - isolatedModules: true, - tsconfig: 'tsconfig.codeflash.json', -}}; + # Extend the original config + jest_config_content = f"""// Auto-generated by codeflash for ESM compatibility +const originalConfig = require('{config_require_path}'); module.exports = {{ ...originalConfig, - transform: {{ - ...originalConfig.transform, - '^.+\\\\.tsx?$': ['ts-jest', tsJestOptions], - }}, - globals: {{ - ...originalConfig.globals, - 'ts-jest': tsJestOptions, - }}, + // Transform ESM packages that don't work with Jest's default config + // Pattern handles both npm/yarn (node_modules/pkg) and pnpm (node_modules/.pnpm/pkg@version/node_modules/pkg) + transformIgnorePatterns: [ + 'node_modules/(?!(\\\\.pnpm/)?({esm_pattern}))', + ],{transform_config} }}; """ else: - # Create a minimal Jest config for TypeScript - jest_config_content = """// Auto-generated by codeflash for bundler moduleResolution compatibility -const tsJestOptions = { - isolatedModules: true, - tsconfig: 'tsconfig.codeflash.json', -}; - -module.exports = { + # Create a minimal Jest config for TypeScript with ESM support + jest_config_content = f"""// Auto-generated by codeflash for ESM compatibility +module.exports = {{ verbose: true, testEnvironment: 'node', testRegex: '\\\\.(test|spec)\\\\.(js|ts|tsx)$', - testPathIgnorePatterns: ['/dist/', '/node_modules/'], - transform: { - '^.+\\\\.tsx?$': ['ts-jest', tsJestOptions], - }, + testPathIgnorePatterns: ['/dist/'], + // Transform ESM packages that don't work with Jest's default config + // Pattern handles both npm/yarn and pnpm directory structures + transformIgnorePatterns: [ + 'node_modules/(?!(\\\\.pnpm/)?({esm_pattern}))', + ],{transform_config} moduleFileExtensions: ['ts', 'tsx', 'js', 'jsx', 'json', 'node'], -}; +}}; """ try: codeflash_jest_config_path.write_text(jest_config_content) - logger.debug(f"Created {codeflash_jest_config_path} with codeflash tsconfig") + _created_config_files.add(codeflash_jest_config_path) + logger.debug(f"Created {codeflash_jest_config_path} with ESM package support") return codeflash_jest_config_path except Exception as e: logger.warning(f"Failed to create codeflash Jest config: {e}") @@ -323,6 +416,55 @@ def _find_monorepo_root(start_path: Path) -> Path | None: return None +def _get_jest_major_version(project_root: Path) -> int | None: + """Detect the major version of Jest installed in the project. + + Args: + project_root: Root of the project to check. + + Returns: + Major version number (e.g., 29, 30), or None if not detected. + + """ + # First try to check package.json for explicit version + package_json = project_root / "package.json" + if package_json.exists(): + try: + content = json.loads(package_json.read_text()) + deps = {**content.get("devDependencies", {}), **content.get("dependencies", {})} + jest_version = deps.get("jest", "") + # Parse version like "30.0.5", "^30.0.5", "~30.0.5" + if jest_version: + # Strip leading version prefixes (^, ~, =, v) + version_str = jest_version.lstrip("^~=v") + if version_str and version_str[0].isdigit(): + major = version_str.split(".")[0] + if major.isdigit(): + return int(major) + except (json.JSONDecodeError, OSError): + pass + + # Also check monorepo root + monorepo_root = _find_monorepo_root(project_root) + if monorepo_root and monorepo_root != project_root: + monorepo_package = monorepo_root / "package.json" + if monorepo_package.exists(): + try: + content = json.loads(monorepo_package.read_text()) + deps = {**content.get("devDependencies", {}), **content.get("dependencies", {})} + jest_version = deps.get("jest", "") + if jest_version: + version_str = jest_version.lstrip("^~=v") + if version_str and version_str[0].isdigit(): + major = version_str.split(".")[0] + if major.isdigit(): + return int(major) + except (json.JSONDecodeError, OSError): + pass + + return None + + def _find_jest_config(project_root: Path) -> Path | None: """Find Jest configuration file in the project. @@ -535,7 +677,6 @@ def run_jest_behavioral_tests( # Get test files to run test_files = [str(file.instrumented_behavior_file_path) for file in test_paths.test_files] - # Use provided project_root, or detect it as fallback if project_root is None and test_files: first_test_file = Path(test_files[0]) @@ -610,13 +751,25 @@ def run_jest_behavioral_tests( # Configure ESM support if project uses ES Modules _configure_esm_environment(jest_env, effective_cwd) + # Increase Node.js heap size for large TypeScript projects + # Default heap is often not enough for monorepos with many dependencies + existing_node_options = jest_env.get("NODE_OPTIONS", "") + if "--max-old-space-size" not in existing_node_options: + jest_env["NODE_OPTIONS"] = f"{existing_node_options} --max-old-space-size=4096".strip() + logger.debug(f"Running Jest tests with command: {' '.join(jest_cmd)}") + # Calculate subprocess timeout: needs to be much larger than per-test timeout + # to account for Jest startup, TypeScript compilation, module loading, etc. + # Use at least 120 seconds, or 10x the per-test timeout, whichever is larger + subprocess_timeout = max(120, (timeout or 15) * 10, 600) if timeout else 600 + start_time_ns = time.perf_counter_ns() try: run_args = get_cross_platform_subprocess_run_args( - cwd=effective_cwd, env=jest_env, timeout=timeout or 600, check=False, text=True, capture_output=True + cwd=effective_cwd, env=jest_env, timeout=subprocess_timeout, check=False, text=True, capture_output=True ) + logger.debug(f"Jest subprocess timeout: {subprocess_timeout}s (per-test timeout: {timeout}s)") result = subprocess.run(jest_cmd, **run_args) # noqa: PLW1510 # Jest sends console.log output to stderr by default - move it to stdout # so our timing markers (printed via console.log) are in the expected place @@ -634,12 +787,12 @@ def run_jest_behavioral_tests( # This helps debug issues like import errors that cause Jest to fail early if result.returncode != 0 and not result_file_path.exists(): logger.warning( - f"Jest failed with returncode={result.returncode} and no XML output created.\n" + f"Jest failed with returncode={result.returncode}.\n" f"Jest stdout: {result.stdout[:2000] if result.stdout else '(empty)'}\n" f"Jest stderr: {result.stderr[:500] if result.stderr else '(empty)'}" ) except subprocess.TimeoutExpired: - logger.warning(f"Jest tests timed out after {timeout}s") + logger.warning(f"Jest tests timed out after {subprocess_timeout}s") result = subprocess.CompletedProcess(args=jest_cmd, returncode=-1, stdout="", stderr="Test execution timed out") except FileNotFoundError: logger.error("Jest not found. Make sure Jest is installed (npm install jest)") @@ -650,6 +803,8 @@ def run_jest_behavioral_tests( wall_clock_ns = time.perf_counter_ns() - start_time_ns logger.debug(f"Jest behavioral tests completed in {wall_clock_ns / 1e9:.2f}s") + print(result.stdout) + return result_file_path, result, coverage_json_path, None @@ -774,25 +929,28 @@ def run_jest_benchmarking_tests( # Get performance test files test_files = [str(file.benchmarking_file_path) for file in test_paths.test_files if file.benchmarking_file_path] - # Use provided project_root, or detect it as fallback if project_root is None and test_files: first_test_file = Path(test_files[0]) project_root = _find_node_project_root(first_test_file) effective_cwd = project_root if project_root else cwd - logger.debug(f"Jest benchmarking working directory: {effective_cwd}") # Ensure the codeflash npm package is installed _ensure_runtime_files(effective_cwd) - # Build Jest command for performance tests with custom loop runner + # Detect Jest version for logging + jest_major_version = _get_jest_major_version(effective_cwd) + if jest_major_version: + logger.debug(f"Jest {jest_major_version} detected - using loop-runner for batched looping") + + # Build Jest command for performance tests jest_cmd = [ "npx", "jest", "--reporters=default", "--reporters=jest-junit", - "--runInBand", # Ensure serial execution even though runner enforces it + "--runInBand", # Ensure serial execution "--forceExit", "--runner=codeflash/loop-runner", # Use custom loop runner for in-process looping ] @@ -844,9 +1002,25 @@ def run_jest_benchmarking_tests( jest_env["CODEFLASH_PERF_STABILITY_CHECK"] = "true" if stability_check else "false" jest_env["CODEFLASH_LOOP_INDEX"] = "1" # Initial value for compatibility + # Enable console output for timing markers + # Some projects mock console.log in test setup (e.g., based on LOG_LEVEL or DEBUG) + # We need console.log to work for capturePerf timing markers + jest_env["LOG_LEVEL"] = "info" # Disable console.log mocking in projects that check LOG_LEVEL + jest_env["DEBUG"] = "1" # Disable console.log mocking in projects that check DEBUG + + # Debug logging for loop behavior verification (set CODEFLASH_DEBUG_LOOPS=true to enable) + if os.environ.get("CODEFLASH_DEBUG_LOOPS") == "true": + jest_env["CODEFLASH_DEBUG_LOOPS"] = "true" + logger.info("Loop debug logging enabled - will show capturePerf loop details") + # Configure ESM support if project uses ES Modules _configure_esm_environment(jest_env, effective_cwd) + # Increase Node.js heap size for large TypeScript projects + existing_node_options = jest_env.get("NODE_OPTIONS", "") + if "--max-old-space-size" not in existing_node_options: + jest_env["NODE_OPTIONS"] = f"{existing_node_options} --max-old-space-size=4096".strip() + # Total timeout for the entire benchmark run (longer than single-loop timeout) # Account for startup overhead + target duration + buffer total_timeout = max(120, (target_duration_ms // 1000) + 60, timeout or 120) @@ -882,7 +1056,6 @@ def run_jest_benchmarking_tests( wall_clock_seconds = time.time() - total_start_time logger.debug(f"Jest benchmarking completed in {wall_clock_seconds:.2f}s") - return result_file_path, result @@ -985,6 +1158,11 @@ def run_jest_line_profile_tests( # Configure ESM support if project uses ES Modules _configure_esm_environment(jest_env, effective_cwd) + # Increase Node.js heap size for large TypeScript projects + existing_node_options = jest_env.get("NODE_OPTIONS", "") + if "--max-old-space-size" not in existing_node_options: + jest_env["NODE_OPTIONS"] = f"{existing_node_options} --max-old-space-size=4096".strip() + subprocess_timeout = timeout or 600 logger.debug(f"Running Jest line profile tests: {' '.join(jest_cmd)}") diff --git a/codeflash/languages/javascript/treesitter.py b/codeflash/languages/javascript/treesitter.py index 650d899a5..32d2431ac 100644 --- a/codeflash/languages/javascript/treesitter.py +++ b/codeflash/languages/javascript/treesitter.py @@ -69,6 +69,7 @@ class FunctionNode: parent_function: str | None source_text: str doc_start_line: int | None = None # Line where JSDoc comment starts (or None if no JSDoc) + is_exported: bool = False # Whether the function is exported @dataclass @@ -295,6 +296,7 @@ def _extract_function_info( is_generator = False is_method = False is_arrow = node.type == "arrow_function" + is_exported = False # Check for async modifier for child in node.children: @@ -306,6 +308,12 @@ def _extract_function_info( if "generator" in node.type: is_generator = True + # Check if function is exported + # For function_declaration: check if parent is export_statement + # For arrow functions: check if parent variable_declarator's grandparent is export_statement + # For CommonJS: check module.exports = { name } or exports.name = ... + is_exported = self._is_node_exported(node, source_bytes) + # Get function name based on node type if node.type in ("function_declaration", "generator_function_declaration"): name_node = node.child_by_field_name("name") @@ -355,8 +363,157 @@ def _extract_function_info( parent_function=current_function, source_text=source_text, doc_start_line=doc_start_line, + is_exported=is_exported, ) + def _is_node_exported(self, node: Node, source_bytes: bytes | None = None) -> bool: + """Check if a function node is exported. + + Handles various export patterns: + - export function foo() {} + - export const foo = () => {} + - export default function foo() {} + - Class methods in exported classes + - module.exports = { foo } (CommonJS) + - exports.foo = ... (CommonJS) + + Args: + node: The function node to check. + source_bytes: Source code bytes (needed for CommonJS export detection). + + Returns: + True if the function is exported, False otherwise. + + """ + # Check direct parent for export_statement + if node.parent and node.parent.type == "export_statement": + return True + + # For arrow functions and function expressions assigned to variables + # e.g., export const foo = () => {} + if node.type in ("arrow_function", "function_expression", "generator_function"): + parent = node.parent + if parent and parent.type == "variable_declarator": + grandparent = parent.parent + if grandparent and grandparent.type in ("lexical_declaration", "variable_declaration"): + great_grandparent = grandparent.parent + if great_grandparent and great_grandparent.type == "export_statement": + return True + + # For methods in exported classes + if node.type == "method_definition": + # Walk up to find class_declaration + current = node.parent + while current: + if current.type in ("class_declaration", "class"): + # Check if this class is exported via ES module export + if current.parent and current.parent.type == "export_statement": + return True + # Check if class is exported via CommonJS + if source_bytes: + class_name_node = current.child_by_field_name("name") + if class_name_node: + class_name = self.get_node_text(class_name_node, source_bytes) + if self._is_name_in_commonjs_exports(node, class_name, source_bytes): + return True + break + current = current.parent + + # Check CommonJS exports: module.exports = { foo } or exports.foo = ... + if source_bytes: + func_name = self._get_function_name_for_export_check(node, source_bytes) + if func_name and self._is_name_in_commonjs_exports(node, func_name, source_bytes): + return True + + return False + + def _get_function_name_for_export_check(self, node: Node, source_bytes: bytes) -> str | None: + """Get the function name for export checking.""" + if node.type in ("function_declaration", "generator_function_declaration"): + name_node = node.child_by_field_name("name") + if name_node: + return self.get_node_text(name_node, source_bytes) + elif node.type in ("arrow_function", "function_expression", "generator_function"): + # Get name from variable assignment + parent = node.parent + if parent and parent.type == "variable_declarator": + name_node = parent.child_by_field_name("name") + if name_node and name_node.type == "identifier": + return self.get_node_text(name_node, source_bytes) + return None + + def _is_name_in_commonjs_exports(self, node: Node, name: str, source_bytes: bytes) -> bool: + """Check if a name is exported via CommonJS module.exports or exports. + + Handles patterns like: + - module.exports = { foo, bar } + - module.exports = { foo: someFunc } + - exports.foo = ... + - module.exports.foo = ... + + Args: + node: Any node in the tree (used to find the program root). + name: The name to check for in exports. + source_bytes: Source code bytes. + + Returns: + True if the name is in CommonJS exports. + + """ + # Walk up to find program root + root = node + while root.parent: + root = root.parent + + # Search for CommonJS export patterns in program children + for child in root.children: + if child.type == "expression_statement": + # Look for assignment expressions + for expr in child.children: + if expr.type == "assignment_expression": + if self._check_commonjs_assignment_exports(expr, name, source_bytes): + return True + + return False + + def _check_commonjs_assignment_exports(self, node: Node, name: str, source_bytes: bytes) -> bool: + """Check if a CommonJS assignment exports the given name.""" + left_node = node.child_by_field_name("left") + right_node = node.child_by_field_name("right") + + if not left_node or not right_node: + return False + + left_text = self.get_node_text(left_node, source_bytes) + + # Check module.exports = { name, ... } or module.exports = { key: name, ... } + if left_text == "module.exports" and right_node.type == "object": + for child in right_node.children: + if child.type == "shorthand_property_identifier": + # { foo } - shorthand export + if self.get_node_text(child, source_bytes) == name: + return True + elif child.type == "pair": + # { key: value } - check both key and value + key_node = child.child_by_field_name("key") + value_node = child.child_by_field_name("value") + if key_node and self.get_node_text(key_node, source_bytes) == name: + return True + if value_node and value_node.type == "identifier": + if self.get_node_text(value_node, source_bytes) == name: + return True + + # Check module.exports = name (single export) + if left_text == "module.exports" and right_node.type == "identifier": + if self.get_node_text(right_node, source_bytes) == name: + return True + + # Check module.exports.name = ... or exports.name = ... + if left_text in {f"module.exports.{name}", f"exports.{name}"}: + return True + + return False + def _find_preceding_jsdoc(self, node: Node, source_bytes: bytes) -> int | None: """Find JSDoc comment immediately preceding a function node. diff --git a/codeflash/models/models.py b/codeflash/models/models.py index d56672ba8..86aef25c9 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -964,6 +964,28 @@ def total_passed_runtime(self) -> int: [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()] ) + def effective_loop_count(self) -> int: + """Calculate the effective number of complete loops. + + For consistent behavior across Python and JavaScript tests, this returns + the maximum loop_index seen across all test results. This represents + the number of timing iterations that were performed. + + Note: For JavaScript tests without the loop-runner, each test case may have + different iteration counts due to internal looping in capturePerf. We use + max() to report the highest iteration count achieved. + + :return: The effective loop count, or 0 if no test results. + """ + if not self.test_results: + return 0 + # Get all loop indices from results that have timing data + loop_indices = {result.loop_index for result in self.test_results if result.runtime is not None} + if not loop_indices: + # Fallback: use all loop indices even without runtime + loop_indices = {result.loop_index for result in self.test_results} + return max(loop_indices) if loop_indices else 0 + def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]: map_gen_test_file_to_no_of_tests = Counter() for gen_test_result in self.test_results: diff --git a/codeflash/models/test_type.py b/codeflash/models/test_type.py index e3f196756..154e3f7f2 100644 --- a/codeflash/models/test_type.py +++ b/codeflash/models/test_type.py @@ -10,9 +10,7 @@ class TestType(Enum): INIT_STATE_TEST = 6 def to_name(self) -> str: - if self is TestType.INIT_STATE_TEST: - return "" - return _TO_NAME_MAP[self] + return _TO_NAME_MAP.get(self, "") _TO_NAME_MAP: dict[TestType, str] = { diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 519603416..bb824468e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -80,6 +80,7 @@ from codeflash.languages.base import Language from codeflash.languages.current import current_language_support, is_typescript from codeflash.languages.javascript.module_system import detect_module_system +from codeflash.languages.javascript.test_runner import clear_created_config_files, get_created_config_files from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table, tree_to_markdown from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId from codeflash.models.ExperimentMetadata import ExperimentMetadata @@ -604,6 +605,11 @@ def generate_and_instrument_tests( f.write(generated_test.instrumented_behavior_test_source) logger.debug(f"[PIPELINE] Wrote behavioral test to {generated_test.behavior_file_path}") + # Save perf test source for debugging + debug_file_path = get_run_tmp_file(Path("perf_test_debug.test.ts")) + with debug_file_path.open("w", encoding="utf-8") as debug_f: + debug_f.write(generated_test.instrumented_perf_test_source) + with generated_test.perf_file_path.open("w", encoding="utf8") as f: f.write(generated_test.instrumented_perf_test_source) logger.debug(f"[PIPELINE] Wrote perf test to {generated_test.perf_file_path}") @@ -2079,7 +2085,7 @@ def process_review( formatted_generated_test = format_generated_code(concolic_test_str, self.args.formatter_cmds) generated_tests_str += f"```{code_lang}\n{formatted_generated_test}\n```\n\n" - existing_tests, replay_tests, _ = existing_tests_source_for( + existing_tests, replay_tests, _concolic_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), function_to_all_tests, test_cfg=self.test_cfg, @@ -2394,7 +2400,7 @@ def establish_original_code_baseline( if not success: return Failure("Failed to establish a baseline for the original code.") - loop_count = max([int(result.loop_index) for result in benchmarking_results.test_results]) + loop_count = benchmarking_results.effective_loop_count() logger.info( f"h3|⌚ Original code summed runtime measured over '{loop_count}' loop{'s' if loop_count > 1 else ''}: " f"'{humanize_runtime(total_timing)}' per full loop" @@ -2617,11 +2623,10 @@ def run_optimized_candidate( self.write_code_and_helpers( candidate_fto_code, candidate_helper_code, self.function_to_optimize.file_path ) - loop_count = ( - max(all_loop_indices) - if (all_loop_indices := {result.loop_index for result in candidate_benchmarking_results.test_results}) - else 0 - ) + # Use effective_loop_count which represents the minimum number of timing samples + # across all test cases. This is more accurate for JavaScript tests where + # capturePerf does internal looping with potentially different iteration counts per test. + loop_count = candidate_benchmarking_results.effective_loop_count() if (total_candidate_timing := candidate_benchmarking_results.total_passed_runtime()) == 0: logger.warning("The overall test runtime of the optimized function is 0, couldn't run tests.") @@ -2817,6 +2822,13 @@ def cleanup_generated_files(self) -> None: paths_to_cleanup.append(test_file.instrumented_behavior_file_path) paths_to_cleanup.append(test_file.benchmarking_file_path) + # Clean up created config files (jest.codeflash.config.js, tsconfig.codeflash.json) + config_files = get_created_config_files() + if config_files: + paths_to_cleanup.extend(config_files) + logger.debug(f"Cleaning up {len(config_files)} codeflash config file(s)") + clear_created_config_files() + cleanup_paths(paths_to_cleanup) def get_test_env( diff --git a/codeflash/verification/coverage_utils.py b/codeflash/verification/coverage_utils.py index f0678454e..20535b9f7 100644 --- a/codeflash/verification/coverage_utils.py +++ b/codeflash/verification/coverage_utils.py @@ -58,7 +58,9 @@ def load_from_jest_json( source_path_str = str(source_code_path.resolve()) for file_path, file_data in coverage_data.items(): - if file_path == source_path_str or file_path.endswith(source_code_path.name): + # Match exact path or path ending with full relative path from src/ + # Avoid matching files with same name in different directories (e.g., db/utils.ts vs utils/utils.ts) + if file_path == source_path_str or file_path.endswith(str(source_code_path)): file_coverage = file_data break diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index 76131a78c..c567e6a9a 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -23,12 +23,12 @@ def get_test_file_path( # For JavaScript/TypeScript, place generated tests in a subdirectory that matches # Vitest/Jest include patterns (e.g., test/**/*.test.ts) - if is_javascript(): - # For monorepos, first try to find the package directory from the source file path - # e.g., packages/workflow/src/utils.ts -> packages/workflow/test/codeflash-generated/ - package_test_dir = _find_js_package_test_dir(test_dir, source_file_path) - if package_test_dir: - test_dir = package_test_dir + # if is_javascript(): + # # For monorepos, first try to find the package directory from the source file path + # # e.g., packages/workflow/src/utils.ts -> packages/workflow/test/codeflash-generated/ + # package_test_dir = _find_js_package_test_dir(test_dir, source_file_path) + # if package_test_dir: + # test_dir = package_test_dir path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}{extension}" if path.exists(): diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index 4d49c94c2..78bd2e4ab 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -42,7 +42,20 @@ def generate_tests( source_file = Path(function_to_optimize.file_path) project_module_system = detect_module_system(test_cfg.tests_project_rootdir, source_file) - logger.debug(f"Detected module system: {project_module_system}") + + # For JavaScript, calculate the correct import path from the actual test location + # (test_path) to the source file, not from tests_root + import os + + source_file_abs = source_file.resolve().with_suffix("") + test_dir_abs = test_path.resolve().parent + # Compute relative path from test directory to source file + rel_import_path = os.path.relpath(str(source_file_abs), str(test_dir_abs)) + # Ensure path starts with ./ or ../ for JavaScript/TypeScript imports + if not rel_import_path.startswith("../"): + rel_import_path = f"./{rel_import_path}" + # Keep as string since Path() normalizes away the ./ prefix + module_path = rel_import_path response = aiservice_client.generate_regression_tests( source_code_being_tested=source_code_being_tested, @@ -66,7 +79,8 @@ def generate_tests( if is_javascript(): from codeflash.languages.javascript.instrument import ( TestingMode, - fix_import_path_for_test_location, + fix_imports_inside_test_blocks, + fix_jest_mock_paths, instrument_generated_js_test, validate_and_fix_import_style, ) @@ -77,10 +91,12 @@ def generate_tests( source_file = Path(function_to_optimize.file_path) - # Fix import paths to be relative to test file location - # AI may generate imports like 'apps/web/app/file' instead of '../../app/file' - generated_test_source = fix_import_path_for_test_location( - generated_test_source, source_file, test_path, module_path + # Fix import statements that appear inside test blocks (invalid JS syntax) + generated_test_source = fix_imports_inside_test_blocks(generated_test_source) + + # Fix relative paths in jest.mock() calls + generated_test_source = fix_jest_mock_paths( + generated_test_source, test_path, source_file, test_cfg.tests_project_rootdir ) # Validate and fix import styles (default vs named exports) diff --git a/packages/codeflash/runtime/capture.js b/packages/codeflash/runtime/capture.js index 616e2907c..0fdcc5784 100644 --- a/packages/codeflash/runtime/capture.js +++ b/packages/codeflash/runtime/capture.js @@ -87,6 +87,8 @@ if (!process[PERF_STATE_KEY]) { shouldStop: false, // Flag to stop all further looping currentBatch: 0, // Current batch number (incremented by runner) invocationLoopCounts: {}, // Track loops per invocation: {invocationKey: loopCount} + invocationRuntimes: {}, // Track runtimes per invocation for stability: {invocationKey: [runtimes]} + stableInvocations: {}, // Invocations that have reached stability: {invocationKey: true} }; } const sharedPerfState = process[PERF_STATE_KEY]; @@ -98,10 +100,10 @@ const sharedPerfState = process[PERF_STATE_KEY]; function checkSharedTimeLimit() { if (sharedPerfState.shouldStop) return true; if (sharedPerfState.startTime === null) { - sharedPerfState.startTime = Date.now(); + sharedPerfState.startTime = _ORIGINAL_DATE_NOW(); return false; } - const elapsed = Date.now() - sharedPerfState.startTime; + const elapsed = _ORIGINAL_DATE_NOW() - sharedPerfState.startTime; if (elapsed >= getPerfTargetDurationMs() && sharedPerfState.totalLoopsCompleted >= getPerfMinLoops()) { sharedPerfState.shouldStop = true; return true; @@ -111,25 +113,33 @@ function checkSharedTimeLimit() { /** * Get the current loop index for a specific invocation. - * Each invocation tracks its own loop count independently within a batch. - * The actual loop index is computed as: (batch - 1) * BATCH_SIZE + localIndex - * This ensures continuous loop indices even when Jest resets module state. + * The loop index represents how many times ALL test files have been run through. + * This is the batch count from the loop-runner. * @param {string} invocationKey - Unique key for this test invocation - * @returns {number} The next global loop index for this invocation + * @returns {number} The current batch number (loop index) */ function getInvocationLoopIndex(invocationKey) { - // Track local loop count within this batch (starts at 0) + // Track local loop count for stopping logic (increments on each call) if (!sharedPerfState.invocationLoopCounts[invocationKey]) { sharedPerfState.invocationLoopCounts[invocationKey] = 0; } - const localIndex = ++sharedPerfState.invocationLoopCounts[invocationKey]; + ++sharedPerfState.invocationLoopCounts[invocationKey]; - // Calculate global loop index using batch number from environment - // PERF_CURRENT_BATCH is 1-based (set by loop-runner before each batch) - const currentBatch = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10); - const globalIndex = (currentBatch - 1) * getPerfBatchSize() + localIndex; + // Return the batch number as the loop index for timing markers + // This represents how many times all test files have been run through + return parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10); +} - return globalIndex; +/** + * Get the total number of iterations for a specific invocation. + * Used for stopping logic to check against max loop count. + * @param {string} invocationKey - Unique key for this test invocation + * @returns {number} Total iterations across all batches for this invocation + */ +function getTotalIterations(invocationKey) { + const localCount = sharedPerfState.invocationLoopCounts[invocationKey] || 0; + const currentBatch = parseInt(process.env.CODEFLASH_PERF_CURRENT_BATCH || '1', 10); + return (currentBatch - 1) * getPerfBatchSize() + localCount; } /** @@ -164,6 +174,8 @@ function createSeededRandom(seed) { return ((t ^ t >>> 14) >>> 0) / 4294967296; }; } +let _ORIGINAL_DATE = Date +let _ORIGINAL_DATE_NOW = Date.now // Override non-deterministic APIs with seeded versions if seed is provided // NOTE: We do NOT seed performance.now() or process.hrtime() as those are used @@ -176,8 +188,8 @@ if (RANDOM_SEED !== 0) { // Seed Date.now() and new Date() - use fixed base timestamp that increments const SEEDED_BASE_TIME = 1700000000000; // Nov 14, 2023 - fixed reference point let dateOffset = 0; - const OriginalDate = Date; - const originalDateNow = Date.now; + _ORIGINAL_DATE = Date; + _ORIGINAL_DATE_NOW = Date.now; Date.now = function() { return SEEDED_BASE_TIME + (dateOffset++); @@ -187,15 +199,15 @@ if (RANDOM_SEED !== 0) { function SeededDate(...args) { if (args.length === 0) { // No arguments: use seeded current time - return new OriginalDate(SEEDED_BASE_TIME + (dateOffset++)); + return new _ORIGINAL_DATE(SEEDED_BASE_TIME + (dateOffset++)); } // With arguments: use original behavior - return new OriginalDate(...args); + return new _ORIGINAL_DATE(...args); } - SeededDate.prototype = OriginalDate.prototype; + SeededDate.prototype = _ORIGINAL_DATE.prototype; SeededDate.now = Date.now; - SeededDate.parse = OriginalDate.parse; - SeededDate.UTC = OriginalDate.UTC; + SeededDate.parse = _ORIGINAL_DATE.parse; + SeededDate.UTC = _ORIGINAL_DATE.UTC; global.Date = SeededDate; // Seed crypto.randomUUID() and crypto.getRandomValues() @@ -265,26 +277,40 @@ const results = []; let db = null; /** - * Check if performance has stabilized (for internal looping). - * Matches Python's pytest_plugin.should_stop() logic. + * Check if performance has stabilized, allowing early stopping of benchmarks. + * Matches Python's pytest_plugin.should_stop() logic for consistency. + * + * Performance is considered stable when BOTH conditions are met: + * 1. CENTER: All recent measurements are within Β±10% of the median + * 2. SPREAD: The range (max-min) is within 10% of the minimum + * + * @param {Array} runtimes - Array of runtime measurements in microseconds + * @param {number} window - Number of recent measurements to check + * @param {number} minWindowSize - Minimum samples required before checking + * @returns {boolean} True if performance has stabilized */ function shouldStopStability(runtimes, window, minWindowSize) { if (runtimes.length < window || runtimes.length < minWindowSize) { return false; } + const recent = runtimes.slice(-window); const recentSorted = [...recent].sort((a, b) => a - b); const mid = Math.floor(window / 2); const median = window % 2 ? recentSorted[mid] : (recentSorted[mid - 1] + recentSorted[mid]) / 2; + // Check CENTER: all recent points must be close to median for (const r of recent) { if (Math.abs(r - median) / median > STABILITY_CENTER_TOLERANCE) { return false; } } + + // Check SPREAD: range must be small relative to minimum const rMin = recentSorted[0]; const rMax = recentSorted[recentSorted.length - 1]; if (rMin === 0) return false; + return (rMax - rMin) / rMin <= STABILITY_SPREAD_TOLERANCE; } @@ -673,17 +699,32 @@ function capturePerf(funcName, lineId, fn, ...args) { ? (hasExternalLoopRunner ? getPerfBatchSize() : getPerfLoopCount()) : 1; + // Initialize runtime tracking for this invocation if needed + if (!sharedPerfState.invocationRuntimes[invocationKey]) { + sharedPerfState.invocationRuntimes[invocationKey] = []; + } + const runtimes = sharedPerfState.invocationRuntimes[invocationKey]; + + // Calculate stability window size based on collected runtimes + const getStabilityWindow = () => Math.max(getPerfMinLoops(), Math.ceil(runtimes.length * STABILITY_WINDOW_SIZE)); + for (let batchIndex = 0; batchIndex < batchSize; batchIndex++) { // Check shared time limit BEFORE each iteration if (shouldLoop && checkSharedTimeLimit()) { break; } - // Get the global loop index for this invocation (increments across batches) + // Check if this invocation has already reached stability + if (getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) { + break; + } + + // Get the loop index (batch number) for timing markers const loopIndex = getInvocationLoopIndex(invocationKey); // Check if we've exceeded max loops for this invocation - if (loopIndex > getPerfLoopCount()) { + const totalIterations = getTotalIterations(invocationKey); + if (totalIterations > getPerfLoopCount()) { break; } @@ -703,23 +744,17 @@ function capturePerf(funcName, lineId, fn, ...args) { const endTime = getTimeNs(); durationNs = getDurationNs(startTime, endTime); - // Handle promises - for async functions, run once and return + // Handle promises - for async functions, we need to handle looping differently + // Since we can't use await in the sync loop, delegate to async helper if (lastReturnValue instanceof Promise) { - return lastReturnValue.then( - (resolved) => { - const asyncEndTime = getTimeNs(); - const asyncDurationNs = getDurationNs(startTime, asyncEndTime); - console.log(`!######${testStdoutTag}:${asyncDurationNs}######!`); - sharedPerfState.totalLoopsCompleted++; - return resolved; - }, - (err) => { - const asyncEndTime = getTimeNs(); - const asyncDurationNs = getDurationNs(startTime, asyncEndTime); - console.log(`!######${testStdoutTag}:${asyncDurationNs}######!`); - sharedPerfState.totalLoopsCompleted++; - throw err; - } + // For async functions, delegate to the async looping helper + // Pass along all the context needed for continued looping + return _capturePerfAsync( + funcName, lineId, fn, args, + lastReturnValue, startTime, testStdoutTag, + safeModulePath, testClassName, safeTestFunctionName, + invocationKey, runtimes, batchSize, batchIndex, + shouldLoop, getStabilityWindow ); } @@ -735,6 +770,20 @@ function capturePerf(funcName, lineId, fn, ...args) { // Update shared loop counter sharedPerfState.totalLoopsCompleted++; + // Track runtime for stability check (convert to microseconds) + if (durationNs > 0) { + runtimes.push(durationNs / 1000); + } + + // Check stability after accumulating enough samples + if (getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) { + const window = getStabilityWindow(); + if (shouldStopStability(runtimes, window, getPerfMinLoops())) { + sharedPerfState.stableInvocations[invocationKey] = true; + break; + } + } + // If we had an error, stop looping if (lastError) { break; @@ -751,6 +800,123 @@ function capturePerf(funcName, lineId, fn, ...args) { return lastReturnValue; } +/** + * Helper to record async timing and update state. + * @private + */ +function _recordAsyncTiming(startTime, testStdoutTag, durationNs, runtimes) { + console.log(`!######${testStdoutTag}:${durationNs}######!`); + sharedPerfState.totalLoopsCompleted++; + if (durationNs > 0) { + runtimes.push(durationNs / 1000); + } +} + +/** + * Async helper for capturePerf to handle async function looping. + * This function awaits promises and continues the benchmark loop properly. + * + * @private + * @param {string} funcName - Name of the function being benchmarked + * @param {string} lineId - Line identifier for this capture point + * @param {Function} fn - The async function to benchmark + * @param {Array} args - Arguments to pass to fn + * @param {Promise} firstPromise - The first promise that was already started + * @param {number} firstStartTime - Start time of the first execution + * @param {string} firstTestStdoutTag - Timing marker tag for the first execution + * @param {string} safeModulePath - Sanitized module path + * @param {string|null} testClassName - Test class name (if any) + * @param {string} safeTestFunctionName - Sanitized test function name + * @param {string} invocationKey - Unique key for this invocation + * @param {Array} runtimes - Array to collect runtimes for stability checking + * @param {number} batchSize - Number of iterations per batch + * @param {number} startBatchIndex - Index where async looping started + * @param {boolean} shouldLoop - Whether to continue looping + * @param {Function} getStabilityWindow - Function to get stability window size + * @returns {Promise} The last return value from fn + */ +async function _capturePerfAsync( + funcName, lineId, fn, args, + firstPromise, firstStartTime, firstTestStdoutTag, + safeModulePath, testClassName, safeTestFunctionName, + invocationKey, runtimes, batchSize, startBatchIndex, + shouldLoop, getStabilityWindow +) { + let lastReturnValue; + let lastError = null; + + // Handle the first promise that was already started + try { + lastReturnValue = await firstPromise; + const asyncEndTime = getTimeNs(); + const asyncDurationNs = getDurationNs(firstStartTime, asyncEndTime); + _recordAsyncTiming(firstStartTime, firstTestStdoutTag, asyncDurationNs, runtimes); + } catch (err) { + const asyncEndTime = getTimeNs(); + const asyncDurationNs = getDurationNs(firstStartTime, asyncEndTime); + _recordAsyncTiming(firstStartTime, firstTestStdoutTag, asyncDurationNs, runtimes); + lastError = err; + // Don't throw yet - we want to record the timing first + } + + // If first iteration failed, stop and throw + if (lastError) { + throw lastError; + } + + // Continue looping for remaining iterations + for (let batchIndex = startBatchIndex + 1; batchIndex < batchSize; batchIndex++) { + // Check exit conditions before starting next iteration + if (shouldLoop && checkSharedTimeLimit()) { + break; + } + + if (getPerfStabilityCheck() && sharedPerfState.stableInvocations[invocationKey]) { + break; + } + + // Get the loop index (batch number) for timing markers + const loopIndex = getInvocationLoopIndex(invocationKey); + + // Check if we've exceeded max loops for this invocation + const totalIterations = getTotalIterations(invocationKey); + if (totalIterations > getPerfLoopCount()) { + break; + } + + // Generate timing marker identifiers + const testId = `${safeModulePath}:${testClassName}:${safeTestFunctionName}:${lineId}:${loopIndex}`; + const invocationIndex = getInvocationIndex(testId); + const invocationId = `${lineId}_${invocationIndex}`; + const testStdoutTag = `${safeModulePath}:${testClassName ? testClassName + '.' : ''}${safeTestFunctionName}:${funcName}:${loopIndex}:${invocationId}`; + + // Execute and time the function + try { + const startTime = getTimeNs(); + lastReturnValue = await fn(...args); + const endTime = getTimeNs(); + const durationNs = getDurationNs(startTime, endTime); + + _recordAsyncTiming(startTime, testStdoutTag, durationNs, runtimes); + + // Check if we've reached performance stability + if (getPerfStabilityCheck() && runtimes.length >= getPerfMinLoops()) { + const window = getStabilityWindow(); + if (shouldStopStability(runtimes, window, getPerfMinLoops())) { + sharedPerfState.stableInvocations[invocationKey] = true; + break; + } + } + } catch (e) { + lastError = e; + break; + } + } + + if (lastError) throw lastError; + return lastReturnValue; +} + /** * Capture multiple invocations for benchmarking. * @@ -789,7 +955,7 @@ function writeResults() { const output = { version: '1.0.0', loopIndex: LOOP_INDEX, - timestamp: Date.now(), + timestamp: _ORIGINAL_DATE_NOW(), results }; fs.writeFileSync(jsonPath, JSON.stringify(output, null, 2)); @@ -806,6 +972,8 @@ function resetPerfState() { sharedPerfState.startTime = null; sharedPerfState.totalLoopsCompleted = 0; sharedPerfState.shouldStop = false; + sharedPerfState.invocationRuntimes = {}; + sharedPerfState.stableInvocations = {}; } /** diff --git a/packages/codeflash/runtime/loop-runner.js b/packages/codeflash/runtime/loop-runner.js index 6bfde0c4c..c6d476f1f 100644 --- a/packages/codeflash/runtime/loop-runner.js +++ b/packages/codeflash/runtime/loop-runner.js @@ -24,6 +24,8 @@ * NOTE: This runner requires jest-runner to be installed in your project. * It is a Jest-specific feature and does not work with Vitest. * For Vitest projects, capturePerf() does all loops internally in a single call. + * + * Compatibility: Works with Jest 29.x and Jest 30.x */ 'use strict'; @@ -32,10 +34,26 @@ const { createRequire } = require('module'); const path = require('path'); const fs = require('fs'); +/** + * Validates that a jest-runner path is valid by checking for package.json. + * @param {string} jestRunnerPath - Path to check + * @returns {boolean} True if valid jest-runner package + */ +function isValidJestRunnerPath(jestRunnerPath) { + if (!fs.existsSync(jestRunnerPath)) { + return false; + } + const packageJsonPath = path.join(jestRunnerPath, 'package.json'); + return fs.existsSync(packageJsonPath); +} + /** * Resolve jest-runner with monorepo support. * Uses CODEFLASH_MONOREPO_ROOT environment variable if available, * otherwise walks up the directory tree looking for node_modules/jest-runner. + * + * @returns {string} Path to jest-runner package + * @throws {Error} If jest-runner cannot be found */ function resolveJestRunner() { // Try standard resolution first (works in simple projects) @@ -49,11 +67,8 @@ function resolveJestRunner() { const monorepoRoot = process.env.CODEFLASH_MONOREPO_ROOT; if (monorepoRoot) { const jestRunnerPath = path.join(monorepoRoot, 'node_modules', 'jest-runner'); - if (fs.existsSync(jestRunnerPath)) { - const packageJsonPath = path.join(jestRunnerPath, 'package.json'); - if (fs.existsSync(packageJsonPath)) { - return jestRunnerPath; - } + if (isValidJestRunnerPath(jestRunnerPath)) { + return jestRunnerPath; } } @@ -69,11 +84,8 @@ function resolveJestRunner() { // Try node_modules/jest-runner at this level const jestRunnerPath = path.join(currentDir, 'node_modules', 'jest-runner'); - if (fs.existsSync(jestRunnerPath)) { - const packageJsonPath = path.join(jestRunnerPath, 'package.json'); - if (fs.existsSync(packageJsonPath)) { - return jestRunnerPath; - } + if (isValidJestRunnerPath(jestRunnerPath)) { + return jestRunnerPath; } // Check if this is a workspace root (has monorepo markers) @@ -89,18 +101,53 @@ function resolveJestRunner() { currentDir = path.dirname(currentDir); } - throw new Error('jest-runner not found'); + throw new Error( + 'jest-runner not found. Please install jest-runner in your project: npm install --save-dev jest-runner' + ); } -// Try to load jest-runner - it's a peer dependency that must be installed by the user +/** + * Jest runner components - loaded dynamically from project's node_modules. + * This ensures we use the same version that the project uses. + * + * Jest 30+ uses TestRunner class with event-based architecture. + * Jest 29 uses runTest function for direct test execution. + */ +let TestRunner; let runTest; let jestRunnerAvailable = false; +let jestVersion = 0; try { const jestRunnerPath = resolveJestRunner(); const internalRequire = createRequire(jestRunnerPath); - runTest = internalRequire('./runTest').default; - jestRunnerAvailable = true; + + // Try to get the TestRunner class (Jest 30+) + const jestRunner = internalRequire(jestRunnerPath); + TestRunner = jestRunner.default || jestRunner.TestRunner; + + if (TestRunner && TestRunner.prototype && typeof TestRunner.prototype.runTests === 'function') { + // Jest 30+ - use TestRunner class with event emitter pattern + jestVersion = 30; + jestRunnerAvailable = true; + } else { + // Try Jest 29 style import + try { + runTest = internalRequire('./runTest').default; + if (typeof runTest === 'function') { + // Jest 29 - use direct runTest function + jestVersion = 29; + jestRunnerAvailable = true; + } + } catch (e29) { + // Neither Jest 29 nor 30 style import worked + const errorMsg = `Found jest-runner at ${jestRunnerPath} but could not load it. ` + + `This may indicate an unsupported Jest version. ` + + `Supported versions: Jest 29.x and Jest 30.x`; + console.error(errorMsg); + jestRunnerAvailable = false; + } + } } catch (e) { // jest-runner not installed - this is expected for Vitest projects // The runner will throw a helpful error if someone tries to use it without jest-runner @@ -167,6 +214,9 @@ function deepCopy(obj, seen = new WeakMap()) { /** * Codeflash Loop Runner with Batched Looping + * + * For Jest 30+, extends the TestRunner class directly. + * For Jest 29, uses the runTest function import. */ class CodeflashLoopRunner { constructor(globalConfig, context) { @@ -175,12 +225,24 @@ class CodeflashLoopRunner { 'codeflash/loop-runner requires jest-runner to be installed.\n' + 'Please install it: npm install --save-dev jest-runner\n\n' + 'If you are using Vitest, the loop-runner is not needed - ' + - 'Vitest projects use external looping handled by the Python runner.' + 'Vitest projects use internal looping handled by capturePerf().' ); } + this._globalConfig = globalConfig; this._context = context || {}; this._eventEmitter = new SimpleEventEmitter(); + + // For Jest 30+, create an instance of the base TestRunner for delegation + if (jestVersion >= 30) { + if (!TestRunner) { + throw new Error( + `Jest ${jestVersion} detected but TestRunner class not available. ` + + `This indicates an internal error in loop-runner initialization.` + ); + } + this._baseRunner = new TestRunner(globalConfig, context); + } } get supportsEventEmitters() { @@ -196,7 +258,17 @@ class CodeflashLoopRunner { } /** - * Run tests with batched looping for fair distribution. + * Run tests with batched looping for fair distribution across all test invocations. + * + * This implements the batched looping strategy: + * Batch 1: Test1(N loops) β†’ Test2(N loops) β†’ Test3(N loops) + * Batch 2: Test1(N loops) β†’ Test2(N loops) β†’ Test3(N loops) + * ...until time budget exhausted or max batches reached + * + * @param {Array} tests - Jest test objects to run + * @param {Object} watcher - Jest watcher for interrupt handling + * @param {Object} options - Jest runner options + * @returns {Promise} */ async runTests(tests, watcher, options) { const startTime = Date.now(); @@ -204,59 +276,51 @@ class CodeflashLoopRunner { let hasFailure = false; let allConsoleOutput = ''; - // Import shared state functions from capture module - // We need to do this dynamically since the module may be reloaded - let checkSharedTimeLimit; - let incrementBatch; - try { - const capture = require('codeflash'); - checkSharedTimeLimit = capture.checkSharedTimeLimit; - incrementBatch = capture.incrementBatch; - } catch (e) { - // Fallback if codeflash module not available - checkSharedTimeLimit = () => { - const elapsed = Date.now() - startTime; - return elapsed >= TARGET_DURATION_MS && batchCount >= MIN_BATCHES; - }; - incrementBatch = () => {}; - } + // Time limit check - must use local time tracking because Jest runs tests + // in isolated worker processes where shared state from capture.js isn't accessible + const checkTimeLimit = () => { + const elapsed = Date.now() - startTime; + return elapsed >= TARGET_DURATION_MS && batchCount >= MIN_BATCHES; + }; // Batched looping: run all test files multiple times while (batchCount < MAX_BATCHES) { batchCount++; // Check time limit BEFORE each batch - if (batchCount > MIN_BATCHES && checkSharedTimeLimit()) { + if (batchCount > MIN_BATCHES && checkTimeLimit()) { + console.log(`[codeflash] Time limit reached after ${batchCount - 1} batches (${Date.now() - startTime}ms elapsed)`); break; } // Check if interrupted if (watcher.isInterrupted()) { + console.log(`[codeflash] Watcher is interrupted`) break; } - // Increment batch counter in shared state and set env var - // The env var persists across Jest module resets, ensuring continuous loop indices - incrementBatch(); + // Set env var for batch number - persists across Jest module resets process.env.CODEFLASH_PERF_CURRENT_BATCH = String(batchCount); // Run all test files in this batch - const batchResult = await this._runAllTestsOnce(tests, watcher); + const batchResult = await this._runAllTestsOnce(tests, watcher, options); allConsoleOutput += batchResult.consoleOutput; - if (batchResult.hasFailure) { - hasFailure = true; - break; - } + // if (batchResult.hasFailure) { + // hasFailure = true; + // break; + // } // Check time limit AFTER each batch - if (checkSharedTimeLimit()) { + if (checkTimeLimit()) { + console.log(`[codeflash] Time limit reached after ${batchCount} batches (${Date.now() - startTime}ms elapsed)`); break; } } const totalTimeMs = Date.now() - startTime; + console.log(`[codeflash] now: ${Date.now()}`) // Output all collected console logs - this is critical for timing marker extraction // The console output contains the !######...######! timing markers from capturePerf if (allConsoleOutput) { @@ -268,8 +332,74 @@ class CodeflashLoopRunner { /** * Run all test files once (one batch). + * Uses different approaches for Jest 29 vs Jest 30. + */ + async _runAllTestsOnce(tests, watcher, options) { + if (jestVersion >= 30) { + return this._runAllTestsOnceJest30(tests, watcher, options); + } else { + return this._runAllTestsOnceJest29(tests, watcher); + } + } + + /** + * Jest 30+ implementation - delegates to base TestRunner and collects results. + */ + async _runAllTestsOnceJest30(tests, watcher, options) { + let hasFailure = false; + let allConsoleOutput = ''; + + // For Jest 30, we need to collect results through event listeners + const resultsCollector = []; + + // Subscribe to events from the base runner + const unsubscribeSuccess = this._baseRunner.on('test-file-success', (testData) => { + const [test, result] = testData; + resultsCollector.push({ test, result, success: true }); + + if (result && result.console && Array.isArray(result.console)) { + allConsoleOutput += result.console.map(e => e.message || '').join('\n') + '\n'; + } + + if (result && result.numFailingTests > 0) { + hasFailure = true; + } + + // Forward to our event emitter + this._eventEmitter.emit('test-file-success', testData); + }); + + const unsubscribeFailure = this._baseRunner.on('test-file-failure', (testData) => { + const [test, error] = testData; + resultsCollector.push({ test, error, success: false }); + hasFailure = true; + + // Forward to our event emitter + this._eventEmitter.emit('test-file-failure', testData); + }); + + const unsubscribeStart = this._baseRunner.on('test-file-start', (testData) => { + // Forward to our event emitter + this._eventEmitter.emit('test-file-start', testData); + }); + + try { + // Run tests using the base runner (always serial for benchmarking) + await this._baseRunner.runTests(tests, watcher, { ...options, serial: true }); + } finally { + // Cleanup subscriptions + if (typeof unsubscribeSuccess === 'function') unsubscribeSuccess(); + if (typeof unsubscribeFailure === 'function') unsubscribeFailure(); + if (typeof unsubscribeStart === 'function') unsubscribeStart(); + } + + return { consoleOutput: allConsoleOutput, hasFailure }; + } + + /** + * Jest 29 implementation - uses direct runTest import. */ - async _runAllTestsOnce(tests, watcher) { + async _runAllTestsOnceJest29(tests, watcher) { let hasFailure = false; let allConsoleOutput = ''; diff --git a/tests/test_javascript_function_discovery.py b/tests/test_javascript_function_discovery.py index 9a39086a8..cf76bee2d 100644 --- a/tests/test_javascript_function_discovery.py +++ b/tests/test_javascript_function_discovery.py @@ -23,7 +23,7 @@ def test_simple_function_discovery(self, tmp_path): """Test discovering a simple JavaScript function with return statement.""" js_file = tmp_path / "simple.js" js_file.write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) @@ -39,15 +39,15 @@ def test_multiple_functions_discovery(self, tmp_path): """Test discovering multiple JavaScript functions.""" js_file = tmp_path / "multiple.js" js_file.write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } -function divide(a, b) { +export function divide(a, b) { return a / b; } """) @@ -61,11 +61,11 @@ def test_function_without_return_excluded(self, tmp_path): """Test that functions without return statements are excluded.""" js_file = tmp_path / "no_return.js" js_file.write_text(""" -function withReturn() { +export function withReturn() { return 42; } -function withoutReturn() { +export function withoutReturn() { console.log("hello"); } """) @@ -78,11 +78,11 @@ def test_arrow_function_discovery(self, tmp_path): """Test discovering arrow functions with explicit return.""" js_file = tmp_path / "arrow.js" js_file.write_text(""" -const add = (a, b) => { +export const add = (a, b) => { return a + b; }; -const multiply = (a, b) => a * b; +export const multiply = (a, b) => a * b; """) functions = find_all_functions_in_file(js_file) @@ -95,7 +95,7 @@ def test_class_method_discovery(self, tmp_path): """Test discovering methods inside a JavaScript class.""" js_file = tmp_path / "class.js" js_file.write_text(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -120,11 +120,11 @@ def test_async_function_discovery(self, tmp_path): """Test discovering async JavaScript functions.""" js_file = tmp_path / "async.js" js_file.write_text(""" -async function fetchData(url) { +export async function fetchData(url) { return await fetch(url); } -function syncFunc() { +export function syncFunc() { return 42; } """) @@ -141,7 +141,7 @@ def test_nested_function_excluded(self, tmp_path): """Test that nested functions are handled correctly.""" js_file = tmp_path / "nested.js" js_file.write_text(""" -function outer() { +export function outer() { function inner() { return 1; } @@ -158,11 +158,11 @@ def test_jsx_file_discovery(self, tmp_path): """Test discovering functions in JSX files.""" jsx_file = tmp_path / "component.jsx" jsx_file.write_text(""" -function Button({ onClick }) { +export function Button({ onClick }) { return ; } -function formatText(text) { +export function formatText(text) { return text.toUpperCase(); } """) @@ -176,7 +176,7 @@ def test_invalid_javascript_returns_empty(self, tmp_path): """Test that invalid JavaScript code returns empty results.""" js_file = tmp_path / "invalid.js" js_file.write_text(""" -function broken( { +export function broken( { return 42; } """) @@ -189,11 +189,11 @@ def test_function_line_numbers(self, tmp_path): """Test that function line numbers are correctly detected.""" js_file = tmp_path / "lines.js" js_file.write_text(""" -function firstFunc() { +export function firstFunc() { return 1; } -function secondFunc() { +export function secondFunc() { return 2; } """) @@ -217,7 +217,7 @@ def test_filter_functions_includes_javascript(self, tmp_path): """Test that filter_functions correctly includes JavaScript files.""" js_file = tmp_path / "module.js" js_file.write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) @@ -240,7 +240,7 @@ def test_filter_excludes_test_directory(self, tmp_path): tests_dir.mkdir() test_file = tests_dir / "test_module.test.js" test_file.write_text(""" -function testHelper() { +export function testHelper() { return 42; } """) @@ -260,7 +260,7 @@ def test_filter_excludes_ignored_paths(self, tmp_path): ignored_dir.mkdir() js_file = ignored_dir / "ignored_module.js" js_file.write_text(""" -function ignoredFunc() { +export function ignoredFunc() { return 42; } """) @@ -282,7 +282,7 @@ def test_filter_includes_files_with_dashes(self, tmp_path): """Test that JavaScript files with dashes in name are included (unlike Python).""" js_file = tmp_path / "my-module.js" js_file.write_text(""" -function myFunc() { +export function myFunc() { return 42; } """) @@ -312,11 +312,11 @@ def test_get_functions_from_file(self, tmp_path): """Test getting functions to optimize from a JavaScript file.""" js_file = tmp_path / "string_utils.js" js_file.write_text(""" -function reverseString(str) { +export function reverseString(str) { return str.split('').reverse().join(''); } -function capitalize(str) { +export function capitalize(str) { return str.charAt(0).toUpperCase() + str.slice(1); } """) @@ -422,12 +422,12 @@ def test_discover_all_js_functions(self, tmp_path): """Test discovering all JavaScript functions in a directory.""" # Create multiple JS files (tmp_path / "math.js").write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) (tmp_path / "string.js").write_text(""" -function reverse(str) { +export function reverse(str) { return str.split('').reverse().join(''); } """) @@ -451,7 +451,7 @@ def py_func(): return 1 """) (tmp_path / "js_module.js").write_text(""" -function jsFunc() { +export function jsFunc() { return 1; } """) @@ -476,7 +476,7 @@ def test_qualified_name_no_parents(self, tmp_path): """Test qualified name for top-level function.""" js_file = tmp_path / "module.js" js_file.write_text(""" -function topLevel() { +export function topLevel() { return 42; } """) @@ -490,7 +490,7 @@ def test_qualified_name_with_class_parent(self, tmp_path): """Test qualified name for class method.""" js_file = tmp_path / "module.js" js_file.write_text(""" -class MyClass { +export class MyClass { myMethod() { return 42; } @@ -506,7 +506,7 @@ def test_language_attribute(self, tmp_path): """Test that JavaScript functions have correct language attribute.""" js_file = tmp_path / "module.js" js_file.write_text(""" -function myFunc() { +export function myFunc() { return 42; } """) diff --git a/tests/test_languages/fixtures/js_cjs/helpers/format.js b/tests/test_languages/fixtures/js_cjs/helpers/format.js index d2d50e4df..15dae5e1c 100644 --- a/tests/test_languages/fixtures/js_cjs/helpers/format.js +++ b/tests/test_languages/fixtures/js_cjs/helpers/format.js @@ -8,7 +8,7 @@ * @param decimals - Number of decimal places * @returns Formatted number */ -function formatNumber(num, decimals) { +export function formatNumber(num, decimals) { return Number(num.toFixed(decimals)); } @@ -18,7 +18,7 @@ function formatNumber(num, decimals) { * @param name - Parameter name for error message * @throws Error if value is not a valid number */ -function validateInput(value, name) { +export function validateInput(value, name) { if (typeof value !== 'number' || isNaN(value)) { throw new Error(`Invalid ${name}: must be a number`); } @@ -30,7 +30,7 @@ function validateInput(value, name) { * @param symbol - Currency symbol * @returns Formatted currency string */ -function formatCurrency(amount, symbol = '$') { +export function formatCurrency(amount, symbol = '$') { return `${symbol}${formatNumber(amount, 2)}`; } diff --git a/tests/test_languages/fixtures/js_cjs/math_utils.js b/tests/test_languages/fixtures/js_cjs/math_utils.js index 0b650ed0e..a09a4e880 100644 --- a/tests/test_languages/fixtures/js_cjs/math_utils.js +++ b/tests/test_languages/fixtures/js_cjs/math_utils.js @@ -8,7 +8,7 @@ * @param b - Second number * @returns Sum of a and b */ -function add(a, b) { +export function add(a, b) { return a + b; } @@ -18,7 +18,7 @@ function add(a, b) { * @param b - Second number * @returns Product of a and b */ -function multiply(a, b) { +export function multiply(a, b) { return a * b; } @@ -27,7 +27,7 @@ function multiply(a, b) { * @param n - Non-negative integer * @returns Factorial of n */ -function factorial(n) { +export function factorial(n) { // Intentionally inefficient recursive implementation if (n <= 1) return 1; return n * factorial(n - 1); @@ -39,7 +39,7 @@ function factorial(n) { * @param exp - Exponent * @returns base raised to exp */ -function power(base, exp) { +export function power(base, exp) { // Inefficient: linear time instead of log time let result = 1; for (let i = 0; i < exp; i++) { diff --git a/tests/test_languages/test_code_context_extraction.py b/tests/test_languages/test_code_context_extraction.py index 87c728b34..07946ddd3 100644 --- a/tests/test_languages/test_code_context_extraction.py +++ b/tests/test_languages/test_code_context_extraction.py @@ -56,7 +56,7 @@ class TestSimpleFunctionContext: def test_simple_function_no_dependencies(self, js_support, temp_project): """Test extracting context for a simple standalone function without any dependencies.""" code = """\ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -70,7 +70,7 @@ def test_simple_function_no_dependencies(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -84,7 +84,7 @@ def test_simple_function_no_dependencies(self, js_support, temp_project): def test_arrow_function_with_implicit_return(self, js_support, temp_project): """Test extracting an arrow function with implicit return.""" code = """\ -const multiply = (a, b) => a * b; +export const multiply = (a, b) => a * b; """ file_path = temp_project / "math.js" file_path.write_text(code, encoding="utf-8") @@ -97,7 +97,7 @@ def test_arrow_function_with_implicit_return(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -const multiply = (a, b) => a * b; +export const multiply = (a, b) => a * b; """ assert context.target_code == expected_target_code assert context.helper_functions == [] @@ -116,7 +116,7 @@ def test_function_with_simple_jsdoc(self, js_support, temp_project): * @param {number} b - Second number * @returns {number} The sum */ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -129,13 +129,7 @@ def test_function_with_simple_jsdoc(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -/** - * Adds two numbers together. - * @param {number} a - First number - * @param {number} b - Second number - * @returns {number} The sum - */ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -163,7 +157,7 @@ def test_function_with_complex_jsdoc_types(self, js_support, temp_project): * const doubled = await processItems([1, 2, 3], x => x * 2); * // returns [2, 4, 6] */ -async function processItems(items, callback, options = {}) { +export async function processItems(items, callback, options = {}) { const { parallel = false, chunkSize = 100 } = options; if (!Array.isArray(items)) { @@ -187,25 +181,7 @@ def test_function_with_complex_jsdoc_types(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -/** - * Processes an array of items with a callback function. - * - * This function iterates over each item and applies the transformation. - * - * @template T - The type of items in the input array - * @template U - The type of items in the output array - * @param {Array} items - The input array to process - * @param {function(T, number): U} callback - Transformation function - * @param {Object} [options] - Optional configuration - * @param {boolean} [options.parallel=false] - Whether to process in parallel - * @param {number} [options.chunkSize=100] - Size of processing chunks - * @returns {Promise>} The transformed array - * @throws {TypeError} If items is not an array - * @example - * const doubled = await processItems([1, 2, 3], x => x * 2); - * // returns [2, 4, 6] - */ -async function processItems(items, callback, options = {}) { +export async function processItems(items, callback, options = {}) { const { parallel = false, chunkSize = 100 } = options; if (!Array.isArray(items)) { @@ -231,7 +207,7 @@ def test_class_with_jsdoc_on_class_and_methods(self, js_support, temp_project): * @class CacheManager * @description Provides in-memory caching with automatic expiration. */ -class CacheManager { +export class CacheManager { /** * Creates a new cache manager. * @param {number} defaultTTL - Default time-to-live in milliseconds @@ -275,12 +251,6 @@ class CacheManager { context = js_support.extract_code_context(get_or_compute, temp_project, temp_project) expected_target_code = """\ -/** - * A cache implementation with TTL support. - * - * @class CacheManager - * @description Provides in-memory caching with automatic expiration. - */ class CacheManager { /** * Creates a new cache manager. @@ -344,7 +314,7 @@ def test_jsdoc_with_typedef_and_callback(self, js_support, temp_project): * @param {ValidatorFunction[]} validators - Array of validator functions * @returns {ValidationResult} Combined validation result */ -function validateUserData(data, validators) { +export function validateUserData(data, validators) { const errors = []; const fieldErrors = {}; @@ -377,13 +347,7 @@ def test_jsdoc_with_typedef_and_callback(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -/** - * Validates user input data. - * @param {Object} data - The data to validate - * @param {ValidatorFunction[]} validators - Array of validator functions - * @returns {ValidationResult} Combined validation result - */ -function validateUserData(data, validators) { +export function validateUserData(data, validators) { const errors = []; const fieldErrors = {}; @@ -433,7 +397,7 @@ def test_function_with_multiple_complex_constants(self, js_support, temp_project }; const UNUSED_CONFIG = { debug: false }; -async function fetchWithRetry(endpoint, options = {}) { +export async function fetchWithRetry(endpoint, options = {}) { const url = API_BASE_URL + endpoint; let lastError; @@ -473,7 +437,7 @@ def test_function_with_multiple_complex_constants(self, js_support, temp_project context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -async function fetchWithRetry(endpoint, options = {}) { +export async function fetchWithRetry(endpoint, options = {}) { const url = API_BASE_URL + endpoint; let lastError; @@ -537,7 +501,7 @@ def test_function_with_regex_and_template_constants(self, js_support, temp_proje url: 'Please enter a valid URL' }; -function validateField(value, fieldType) { +export function validateField(value, fieldType) { const pattern = PATTERNS[fieldType]; if (!pattern) { return { valid: true, error: null }; @@ -559,7 +523,7 @@ def test_function_with_regex_and_template_constants(self, js_support, temp_proje context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function validateField(value, fieldType) { +export function validateField(value, fieldType) { const pattern = PATTERNS[fieldType]; if (!pattern) { return { valid: true, error: null }; @@ -595,16 +559,16 @@ class TestSameFileHelperFunctions: def test_function_with_chain_of_helpers(self, js_support, temp_project): """Test function calling helper that calls another helper (transitive dependencies).""" code = """\ -function sanitizeString(str) { +export function sanitizeString(str) { return str.trim().toLowerCase(); } -function normalizeInput(input) { +export function normalizeInput(input) { const sanitized = sanitizeString(input); return sanitized.replace(/\\s+/g, '-'); } -function processUserInput(rawInput) { +export function processUserInput(rawInput) { const normalized = normalizeInput(rawInput); return { original: rawInput, @@ -622,7 +586,7 @@ def test_function_with_chain_of_helpers(self, js_support, temp_project): context = js_support.extract_code_context(process_func, temp_project, temp_project) expected_target_code = """\ -function processUserInput(rawInput) { +export function processUserInput(rawInput) { const normalized = normalizeInput(rawInput); return { original: rawInput, @@ -640,23 +604,23 @@ def test_function_with_chain_of_helpers(self, js_support, temp_project): def test_function_with_multiple_unrelated_helpers(self, js_support, temp_project): """Test function calling multiple independent helper functions.""" code = """\ -function formatDate(date) { +export function formatDate(date) { return date.toISOString().split('T')[0]; } -function formatCurrency(amount) { +export function formatCurrency(amount) { return '$' + amount.toFixed(2); } -function formatPercentage(value) { +export function formatPercentage(value) { return (value * 100).toFixed(1) + '%'; } -function unusedFormatter() { +export function unusedFormatter() { return 'not used'; } -function generateReport(data) { +export function generateReport(data) { const date = formatDate(new Date(data.timestamp)); const revenue = formatCurrency(data.revenue); const growth = formatPercentage(data.growth); @@ -677,7 +641,7 @@ def test_function_with_multiple_unrelated_helpers(self, js_support, temp_project context = js_support.extract_code_context(report_func, temp_project, temp_project) expected_target_code = """\ -function generateReport(data) { +export function generateReport(data) { const date = formatDate(new Date(data.timestamp)); const revenue = formatCurrency(data.revenue); const growth = formatPercentage(data.growth); @@ -699,21 +663,21 @@ def test_function_with_multiple_unrelated_helpers(self, js_support, temp_project for helper in context.helper_functions: if helper.name == "formatDate": expected = """\ -function formatDate(date) { +export function formatDate(date) { return date.toISOString().split('T')[0]; } """ assert helper.source_code == expected elif helper.name == "formatCurrency": expected = """\ -function formatCurrency(amount) { +export function formatCurrency(amount) { return '$' + amount.toFixed(2); } """ assert helper.source_code == expected elif helper.name == "formatPercentage": expected = """\ -function formatPercentage(value) { +export function formatPercentage(value) { return (value * 100).toFixed(1) + '%'; } """ @@ -726,7 +690,7 @@ class TestClassMethodWithSiblingMethods: def test_graph_topological_sort(self, js_support, temp_project): """Test graph class with topological sort - similar to Python test_class_method_dependencies.""" code = """\ -class Graph { +export class Graph { constructor(vertices) { this.graph = new Map(); this.V = vertices; @@ -774,7 +738,7 @@ class Graph { context = js_support.extract_code_context(topo_sort, temp_project, temp_project) - # The extracted code should include class wrapper with constructor + # The extracted code should include class wrapper with constructor and sibling methods used expected_target_code = """\ class Graph { constructor(vertices) { @@ -794,6 +758,19 @@ class Graph { return stack; } + + topologicalSortUtil(v, visited, stack) { + visited[v] = true; + + const neighbors = this.graph.get(v) || []; + for (const i of neighbors) { + if (visited[i] === false) { + this.topologicalSortUtil(i, visited, stack); + } + } + + stack.unshift(v); + } } """ assert context.target_code == expected_target_code @@ -802,7 +779,7 @@ class Graph { def test_class_method_using_nested_helper_class(self, js_support, temp_project): """Test class method that uses another class as a helper - mirrors Python HelperClass test.""" code = """\ -class HelperClass { +export class HelperClass { constructor(name) { this.name = name; } @@ -816,7 +793,7 @@ class HelperClass { } } -class NestedHelper { +export class NestedHelper { constructor(name) { this.name = name; } @@ -826,11 +803,11 @@ class NestedHelper { } } -function mainMethod() { +export function mainMethod() { return 'hello'; } -class MainClass { +export class MainClass { constructor(name) { this.name = name; } @@ -890,7 +867,7 @@ def test_helper_from_another_file_commonjs(self, js_support, temp_project): main_code = """\ const { sorter } = require('./bubble_sort_with_math'); -function sortFromAnotherFile(arr) { +export function sortFromAnotherFile(arr) { const sortedArr = sorter(arr); return sortedArr; } @@ -906,7 +883,7 @@ def test_helper_from_another_file_commonjs(self, js_support, temp_project): context = js_support.extract_code_context(main_func, temp_project, temp_project) expected_target_code = """\ -function sortFromAnotherFile(arr) { +export function sortFromAnotherFile(arr) { const sortedArr = sorter(arr); return sortedArr; } @@ -943,12 +920,10 @@ def test_helper_from_another_file_esm(self, js_support, temp_project): main_code = """\ import identity, { double, triple } from './utils'; -function processNumber(n) { +export function processNumber(n) { const base = identity(n); return double(base) + triple(base); } - -export { processNumber }; """ main_path = temp_project / "main.js" main_path.write_text(main_code, encoding="utf-8") @@ -959,7 +934,7 @@ def test_helper_from_another_file_esm(self, js_support, temp_project): context = js_support.extract_code_context(process_func, temp_project, temp_project) expected_target_code = """\ -function processNumber(n) { +export function processNumber(n) { const base = identity(n); return double(base) + triple(base); } @@ -1007,7 +982,7 @@ def test_chained_imports_across_three_files(self, js_support, temp_project): main_code = """\ import { transformInput } from './middleware'; -function handleUserInput(rawInput) { +export function handleUserInput(rawInput) { try { const result = transformInput(rawInput); return { success: true, data: result }; @@ -1015,8 +990,6 @@ def test_chained_imports_across_three_files(self, js_support, temp_project): return { success: false, error: error.message }; } } - -export { handleUserInput }; """ main_path = temp_project / "main.js" main_path.write_text(main_code, encoding="utf-8") @@ -1027,7 +1000,7 @@ def test_chained_imports_across_three_files(self, js_support, temp_project): context = js_support.extract_code_context(handle_func, temp_project, temp_project) expected_target_code = """\ -function handleUserInput(rawInput) { +export function handleUserInput(rawInput) { try { const result = transformInput(rawInput); return { success: true, data: result }; @@ -1059,7 +1032,7 @@ def test_function_with_complex_generic_types(self, ts_support, temp_project): type Entity = T & Identifiable & Timestamped; -function createEntity(data: T): Entity { +export function createEntity(data: T): Entity { const now = new Date(); return { ...data, @@ -1078,7 +1051,7 @@ def test_function_with_complex_generic_types(self, ts_support, temp_project): context = ts_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function createEntity(data: T): Entity { +export function createEntity(data: T): Entity { const now = new Date(); return { ...data, @@ -1117,7 +1090,7 @@ def test_class_with_private_fields_and_typed_methods(self, ts_support, temp_proj maxSize: number; } -class TypedCache { +export class TypedCache { private readonly cache: Map>; private readonly config: CacheConfig; @@ -1235,15 +1208,13 @@ def test_typescript_with_type_imports(self, ts_support, temp_project): const DEFAULT_ROLE: UserRole = 'user'; -function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User { +export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User { return { id: Math.random().toString(36).substring(2), name: input.name, email: input.email }; } - -export { createUser }; """ service_path = temp_project / "service.ts" service_path.write_text(service_code, encoding="utf-8") @@ -1254,7 +1225,7 @@ def test_typescript_with_type_imports(self, ts_support, temp_project): context = ts_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User { +export function createUser(input: CreateUserInput, role: UserRole = DEFAULT_ROLE): User { return { id: Math.random().toString(36).substring(2), name: input.name, @@ -1294,7 +1265,7 @@ class TestRecursionAndCircularDependencies: def test_self_recursive_factorial(self, js_support, temp_project): """Test self-recursive function does not list itself as helper.""" code = """\ -function factorial(n) { +export function factorial(n) { if (n <= 1) return 1; return n * factorial(n - 1); } @@ -1308,7 +1279,7 @@ def test_self_recursive_factorial(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function factorial(n) { +export function factorial(n) { if (n <= 1) return 1; return n * factorial(n - 1); } @@ -1319,12 +1290,12 @@ def test_self_recursive_factorial(self, js_support, temp_project): def test_mutually_recursive_even_odd(self, js_support, temp_project): """Test mutually recursive functions.""" code = """\ -function isEven(n) { +export function isEven(n) { if (n === 0) return true; return isOdd(n - 1); } -function isOdd(n) { +export function isOdd(n) { if (n === 0) return false; return isEven(n - 1); } @@ -1338,7 +1309,7 @@ def test_mutually_recursive_even_odd(self, js_support, temp_project): context = js_support.extract_code_context(is_even, temp_project, temp_project) expected_target_code = """\ -function isEven(n) { +export function isEven(n) { if (n === 0) return true; return isOdd(n - 1); } @@ -1351,7 +1322,7 @@ def test_mutually_recursive_even_odd(self, js_support, temp_project): # Verify helper source assert context.helper_functions[0].source_code == """\ -function isOdd(n) { +export function isOdd(n) { if (n === 0) return false; return isEven(n - 1); } @@ -1360,28 +1331,28 @@ def test_mutually_recursive_even_odd(self, js_support, temp_project): def test_complex_recursive_tree_traversal(self, js_support, temp_project): """Test complex recursive tree traversal with multiple recursive calls.""" code = """\ -function traversePreOrder(node, visit) { +export function traversePreOrder(node, visit) { if (!node) return; visit(node.value); traversePreOrder(node.left, visit); traversePreOrder(node.right, visit); } -function traverseInOrder(node, visit) { +export function traverseInOrder(node, visit) { if (!node) return; traverseInOrder(node.left, visit); visit(node.value); traverseInOrder(node.right, visit); } -function traversePostOrder(node, visit) { +export function traversePostOrder(node, visit) { if (!node) return; traversePostOrder(node.left, visit); traversePostOrder(node.right, visit); visit(node.value); } -function collectAllValues(root) { +export function collectAllValues(root) { const values = { pre: [], in: [], post: [] }; traversePreOrder(root, v => values.pre.push(v)); @@ -1400,7 +1371,7 @@ def test_complex_recursive_tree_traversal(self, js_support, temp_project): context = js_support.extract_code_context(collect_func, temp_project, temp_project) expected_target_code = """\ -function collectAllValues(root) { +export function collectAllValues(root) { const values = { pre: [], in: [], post: [] }; traversePreOrder(root, v => values.pre.push(v)); @@ -1423,7 +1394,7 @@ class TestAsyncPatternsAndPromises: def test_async_function_chain(self, js_support, temp_project): """Test async function that calls other async functions.""" code = """\ -async function fetchUserById(id) { +export async function fetchUserById(id) { const response = await fetch(`/api/users/${id}`); if (!response.ok) { throw new Error(`User ${id} not found`); @@ -1431,17 +1402,17 @@ def test_async_function_chain(self, js_support, temp_project): return response.json(); } -async function fetchUserPosts(userId) { +export async function fetchUserPosts(userId) { const response = await fetch(`/api/users/${userId}/posts`); return response.json(); } -async function fetchUserComments(userId) { +export async function fetchUserComments(userId) { const response = await fetch(`/api/users/${userId}/comments`); return response.json(); } -async function fetchUserProfile(userId) { +export async function fetchUserProfile(userId) { const user = await fetchUserById(userId); const [posts, comments] = await Promise.all([ fetchUserPosts(userId), @@ -1465,7 +1436,7 @@ def test_async_function_chain(self, js_support, temp_project): context = js_support.extract_code_context(profile_func, temp_project, temp_project) expected_target_code = """\ -async function fetchUserProfile(userId) { +export async function fetchUserProfile(userId) { const user = await fetchUserById(userId); const [posts, comments] = await Promise.all([ fetchUserPosts(userId), @@ -1493,7 +1464,7 @@ class TestExtractionReplacementRoundTrip: def test_extract_and_replace_class_method(self, js_support, temp_project): """Test extracting code context and then replacing the method.""" original_source = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1536,7 +1507,7 @@ class Counter { # Step 2: Simulate AI returning optimized code optimized_code_from_ai = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1551,7 +1522,7 @@ class Counter { result = js_support.replace_function(original_source, increment_func, optimized_code_from_ai) expected_result = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1578,7 +1549,7 @@ class TestEdgeCases: def test_function_with_complex_destructuring(self, js_support, temp_project): """Test function with complex nested destructuring parameters.""" code = """\ -function processApiResponse({ +export function processApiResponse({ data: { users = [], meta: { total, page } = {} } = {}, status, headers: { 'content-type': contentType } = {} @@ -1600,7 +1571,7 @@ def test_function_with_complex_destructuring(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function processApiResponse({ +export function processApiResponse({ data: { users = [], meta: { total, page } = {} } = {}, status, headers: { 'content-type': contentType } = {} @@ -1619,13 +1590,13 @@ def test_function_with_complex_destructuring(self, js_support, temp_project): def test_generator_function(self, js_support, temp_project): """Test generator function extraction.""" code = """\ -function* range(start, end, step = 1) { +export function* range(start, end, step = 1) { for (let i = start; i < end; i += step) { yield i; } } -function* fibonacci(limit) { +export function* fibonacci(limit) { let [a, b] = [0, 1]; while (a < limit) { yield a; @@ -1642,7 +1613,7 @@ def test_generator_function(self, js_support, temp_project): context = js_support.extract_code_context(range_func, temp_project, temp_project) expected_target_code = """\ -function* range(start, end, step = 1) { +export function* range(start, end, step = 1) { for (let i = start; i < end; i += step) { yield i; } @@ -1660,7 +1631,7 @@ def test_function_with_computed_property_names(self, js_support, temp_project): AGE: 'user_age' }; -function createUserObject(name, email, age) { +export function createUserObject(name, email, age) { return { [FIELD_KEYS.NAME]: name, [FIELD_KEYS.EMAIL]: email, @@ -1677,7 +1648,7 @@ def test_function_with_computed_property_names(self, js_support, temp_project): context = js_support.extract_code_context(func, temp_project, temp_project) expected_target_code = """\ -function createUserObject(name, email, age) { +export function createUserObject(name, email, age) { return { [FIELD_KEYS.NAME]: name, [FIELD_KEYS.EMAIL]: email, @@ -1937,7 +1908,7 @@ class TestContextProperties: def test_javascript_context_has_correct_language(self, js_support, temp_project): """Test that JavaScript context has correct language property.""" code = """\ -function test() { +export function test() { return 1; } """ @@ -1956,7 +1927,7 @@ def test_javascript_context_has_correct_language(self, js_support, temp_project) def test_typescript_context_has_javascript_language(self, ts_support, temp_project): """Test that TypeScript context uses JavaScript language enum.""" code = """\ -function test(): number { +export function test(): number { return 1; } """ @@ -1977,7 +1948,7 @@ class TestContextValidation: def test_all_class_methods_produce_valid_syntax(self, js_support, temp_project): """Test that all extracted class methods are syntactically valid JavaScript.""" code = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } diff --git a/tests/test_languages/test_function_discovery_integration.py b/tests/test_languages/test_function_discovery_integration.py index 621a00d79..c91f91fe5 100644 --- a/tests/test_languages/test_function_discovery_integration.py +++ b/tests/test_languages/test_function_discovery_integration.py @@ -89,11 +89,11 @@ def test_javascript_file_routes_to_js_handler(self): """Test that JavaScript files use the JavaScript handler.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function add(a, b) { +export function add(a, b) { return a + b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """) @@ -124,7 +124,7 @@ def test_function_to_optimize_has_correct_fields(self): """Test that FunctionToOptimize has all required fields populated.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -162,7 +162,7 @@ def add(a, b): def test_discovers_javascript_files_when_specified(self, tmp_path): """Test that JavaScript files are discovered when language is specified.""" (tmp_path / "module.js").write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) @@ -177,7 +177,7 @@ def py_func(): return 1 """) (tmp_path / "js_module.js").write_text(""" -function jsFunc() { +export function jsFunc() { return 1; } """) diff --git a/tests/test_languages/test_javascript_e2e.py b/tests/test_languages/test_javascript_e2e.py index 2fe25c18a..017e8f66e 100644 --- a/tests/test_languages/test_javascript_e2e.py +++ b/tests/test_languages/test_javascript_e2e.py @@ -155,16 +155,16 @@ def test_replace_function_in_javascript_file(self): from codeflash.languages.base import FunctionInfo original_source = """ -function add(a, b) { +export function add(a, b) { return a + b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """ - new_function = """function add(a, b) { + new_function = """export function add(a, b) { // Optimized version return a + b; }""" @@ -178,12 +178,12 @@ def test_replace_function_in_javascript_file(self): result = js_support.replace_function(original_source, func_info, new_function) expected_result = """ -function add(a, b) { +export function add(a, b) { // Optimized version return a + b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """ @@ -234,7 +234,7 @@ def test_function_to_optimize_has_correct_fields(self): with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -244,7 +244,7 @@ class Calculator { } } -function standalone(x) { +export function standalone(x) { return x * 2; } """) diff --git a/tests/test_languages/test_javascript_instrumentation.py b/tests/test_languages/test_javascript_instrumentation.py index ba25a3af5..e3457c231 100644 --- a/tests/test_languages/test_javascript_instrumentation.py +++ b/tests/test_languages/test_javascript_instrumentation.py @@ -663,4 +663,314 @@ def test_this_method_call_exact_output(self): expected = " return codeflash.capture('Class.fibonacci', '1', this.fibonacci.bind(this), n - 1);" assert transformed == expected, f"Expected:\n{expected}\nGot:\n{transformed}" - assert counter == 1 \ No newline at end of file + assert counter == 1 + + +class TestFixImportsInsideTestBlocks: + """Tests for fix_imports_inside_test_blocks function.""" + + def test_fix_named_import_inside_test_block(self): + """Test fixing named import inside test function.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + code = """ +test('should work', () => { + const mock = jest.fn(); + import { foo } from '../src/module'; + expect(foo()).toBe(true); +}); +""" + fixed = fix_imports_inside_test_blocks(code) + + assert "const { foo } = require('../src/module');" in fixed + assert "import { foo }" not in fixed + + def test_fix_default_import_inside_test_block(self): + """Test fixing default import inside test function.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + code = """ +test('should work', () => { + env.isTest.mockReturnValue(false); + import queuesModule from '../src/queue/queue'; + expect(queuesModule).toBeDefined(); +}); +""" + fixed = fix_imports_inside_test_blocks(code) + + assert "const queuesModule = require('../src/queue/queue');" in fixed + assert "import queuesModule from" not in fixed + + def test_fix_namespace_import_inside_test_block(self): + """Test fixing namespace import inside test function.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + code = """ +test('should work', () => { + import * as utils from '../src/utils'; + expect(utils.foo()).toBe(true); +}); +""" + fixed = fix_imports_inside_test_blocks(code) + + assert "const utils = require('../src/utils');" in fixed + assert "import * as utils" not in fixed + + def test_preserve_top_level_imports(self): + """Test that top-level imports are not modified.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + code = """ +import { jest, describe, test, expect } from '@jest/globals'; +import { foo } from '../src/module'; + +describe('test suite', () => { + test('should work', () => { + expect(foo()).toBe(true); + }); +}); +""" + fixed = fix_imports_inside_test_blocks(code) + + # Top-level imports should remain unchanged + assert "import { jest, describe, test, expect } from '@jest/globals';" in fixed + assert "import { foo } from '../src/module';" in fixed + + def test_empty_code(self): + """Test handling empty code.""" + from codeflash.languages.javascript.instrument import fix_imports_inside_test_blocks + + assert fix_imports_inside_test_blocks("") == "" + assert fix_imports_inside_test_blocks(" ") == " " + + +class TestFixJestMockPaths: + """Tests for fix_jest_mock_paths function.""" + + def test_fix_mock_path_when_source_relative(self): + """Test fixing mock path that's relative to source file.""" + from codeflash.languages.javascript.instrument import fix_jest_mock_paths + + with tempfile.TemporaryDirectory() as tmpdir: + # Create directory structure + src_dir = Path(tmpdir) / "src" / "queue" + tests_dir = Path(tmpdir) / "tests" + env_file = Path(tmpdir) / "src" / "environment.ts" + + src_dir.mkdir(parents=True) + tests_dir.mkdir(parents=True) + env_file.parent.mkdir(parents=True, exist_ok=True) + env_file.write_text("export const env = {};") + + source_file = src_dir / "queue.ts" + source_file.write_text("import env from '../environment';") + + test_file = tests_dir / "test_queue.test.ts" + + # Test code with incorrect mock path (relative to source, not test) + test_code = """ +import { jest, describe, test, expect } from '@jest/globals'; +jest.mock('../environment'); +jest.mock('../redis/utils'); + +describe('queue', () => { + test('works', () => {}); +}); +""" + fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir) + + # Should fix the path to be relative to the test file + assert "jest.mock('../src/environment')" in fixed + + def test_preserve_valid_mock_path(self): + """Test that valid mock paths are not modified.""" + from codeflash.languages.javascript.instrument import fix_jest_mock_paths + + with tempfile.TemporaryDirectory() as tmpdir: + # Create directory structure + src_dir = Path(tmpdir) / "src" + tests_dir = Path(tmpdir) / "tests" + + src_dir.mkdir(parents=True) + tests_dir.mkdir(parents=True) + + # Create the file being mocked at the correct location + mock_file = src_dir / "utils.ts" + mock_file.write_text("export const utils = {};") + + source_file = src_dir / "main.ts" + source_file.write_text("") + test_file = tests_dir / "test_main.test.ts" + + # Test code with correct mock path (valid from test location) + test_code = """ +jest.mock('../src/utils'); + +describe('main', () => { + test('works', () => {}); +}); +""" + fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir) + + # Should keep the path unchanged since it's valid + assert "jest.mock('../src/utils')" in fixed + + def test_fix_doMock_path(self): + """Test fixing jest.doMock path.""" + from codeflash.languages.javascript.instrument import fix_jest_mock_paths + + with tempfile.TemporaryDirectory() as tmpdir: + # Create directory structure: src/queue/queue.ts imports ../environment (-> src/environment.ts) + src_dir = Path(tmpdir) / "src" + queue_dir = src_dir / "queue" + tests_dir = Path(tmpdir) / "tests" + env_file = src_dir / "environment.ts" + + queue_dir.mkdir(parents=True) + tests_dir.mkdir(parents=True) + env_file.write_text("export const env = {};") + + source_file = queue_dir / "queue.ts" + source_file.write_text("") + test_file = tests_dir / "test_queue.test.ts" + + # From src/queue/queue.ts, ../environment resolves to src/environment.ts + # Test file is at tests/test_queue.test.ts + # So the correct mock path from test should be ../src/environment + test_code = """ +jest.doMock('../environment', () => ({ isTest: jest.fn() })); +""" + fixed = fix_jest_mock_paths(test_code, test_file, source_file, tests_dir) + + # Should fix the doMock path + assert "jest.doMock('../src/environment'" in fixed + + def test_empty_code(self): + """Test handling empty code.""" + from codeflash.languages.javascript.instrument import fix_jest_mock_paths + + with tempfile.TemporaryDirectory() as tmpdir: + tests_dir = Path(tmpdir) / "tests" + tests_dir.mkdir() + source_file = Path(tmpdir) / "src" / "main.ts" + test_file = tests_dir / "test.ts" + + assert fix_jest_mock_paths("", test_file, source_file, tests_dir) == "" + assert fix_jest_mock_paths(" ", test_file, source_file, tests_dir) == " " + + +class TestFunctionCallsInStrings: + """Tests for skipping function calls inside string literals.""" + + def test_skip_function_in_test_description_single_quotes(self): + """Test that function calls in single-quoted test descriptions are not transformed.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + func = make_func("fibonacci") + code = """ +test('should compute fibonacci(20) and fibonacci(30) to known values', () => { + const result = fibonacci(10); + expect(result).toBe(55); +}); +""" + transformed, _counter = transform_standalone_calls(code, func, "capture") + + # The function call in the test description should NOT be transformed + assert "fibonacci(20)" in transformed + assert "fibonacci(30)" in transformed + # The actual call should be transformed + assert "codeflash.capture('fibonacci'" in transformed + + def test_skip_function_in_test_description_double_quotes(self): + """Test that function calls in double-quoted test descriptions are not transformed.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + func = make_func("fibonacci") + code = ''' +test("should compute fibonacci(20) correctly", () => { + const result = fibonacci(10); +}); +''' + transformed, _counter = transform_standalone_calls(code, func, "capture") + + # The function call in the test description should NOT be transformed + assert 'fibonacci(20)' in transformed + # The actual call should be transformed + assert "codeflash.capture('fibonacci'" in transformed + + def test_skip_function_in_template_literal(self): + """Test that function calls in template literals are not transformed.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + func = make_func("fibonacci") + code = """ +test(`should compute fibonacci(20) correctly`, () => { + const result = fibonacci(10); +}); +""" + transformed, _counter = transform_standalone_calls(code, func, "capture") + + # The function call in the template literal should NOT be transformed + assert "fibonacci(20)" in transformed + # The actual call should be transformed + assert "codeflash.capture('fibonacci'" in transformed + + def test_skip_expect_in_string_literal(self): + """Test that expect(func()) in string literals is not transformed.""" + from codeflash.languages.javascript.instrument import transform_expect_calls + + func = make_func("fibonacci") + code = """ +describe('testing expect(fibonacci(n)) patterns', () => { + test('works', () => { + expect(fibonacci(10)).toBe(55); + }); +}); +""" + transformed, _counter = transform_expect_calls(code, func, "capture") + + # The expect in the describe string should NOT be transformed + assert "expect(fibonacci(n))" in transformed + # The actual expect call should be transformed + assert "codeflash.capture('fibonacci'" in transformed + + def test_handle_escaped_quotes_in_string(self): + """Test that escaped quotes in strings are handled correctly.""" + from codeflash.languages.javascript.instrument import transform_standalone_calls + + func = make_func("fibonacci") + code = """ +test('test \\'fibonacci(5)\\' escaping', () => { + const result = fibonacci(10); +}); +""" + transformed, _counter = transform_standalone_calls(code, func, "capture") + + # The function call in the escaped string should NOT be transformed + assert "fibonacci(5)" in transformed + # The actual call should be transformed + assert "codeflash.capture('fibonacci'" in transformed + + def test_is_inside_string_helper(self): + """Test the is_inside_string helper function directly.""" + from codeflash.languages.javascript.instrument import is_inside_string + + # Position inside single-quoted string + code1 = "test('fibonacci(5)', () => {})" + assert is_inside_string(code1, 10) is True # Inside the string + + # Position outside string + assert is_inside_string(code1, 0) is False # Before string + assert is_inside_string(code1, 25) is False # After string + + # Double quotes + code2 = 'test("fibonacci(5)", () => {})' + assert is_inside_string(code2, 10) is True + + # Template literal + code3 = "test(`fibonacci(5)`, () => {})" + assert is_inside_string(code3, 10) is True + + # Escaped quote doesn't end string + code4 = "test('fib\\'s result', () => {})" + assert is_inside_string(code4, 15) is True # Still inside after escaped quote \ No newline at end of file diff --git a/tests/test_languages/test_javascript_optimization_flow.py b/tests/test_languages/test_javascript_optimization_flow.py index 7c7ba5aa6..26d2db140 100644 --- a/tests/test_languages/test_javascript_optimization_flow.py +++ b/tests/test_languages/test_javascript_optimization_flow.py @@ -60,6 +60,7 @@ def test_function_to_optimize_has_correct_language_for_javascript(self, tmp_path function add(a, b) { return a + b; } +module.exports = { add }; """) functions = find_all_functions_in_file(js_file) diff --git a/tests/test_languages/test_javascript_support.py b/tests/test_languages/test_javascript_support.py index 7a6868a66..8a7f9afe1 100644 --- a/tests/test_languages/test_javascript_support.py +++ b/tests/test_languages/test_javascript_support.py @@ -46,7 +46,7 @@ def test_discover_simple_function(self, js_support): """Test discovering a simple function declaration.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function add(a, b) { +export function add(a, b) { return a + b; } """) @@ -62,15 +62,15 @@ def test_discover_multiple_functions(self, js_support): """Test discovering multiple functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function add(a, b) { +export function add(a, b) { return a + b; } -function subtract(a, b) { +export function subtract(a, b) { return a - b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """) @@ -86,11 +86,11 @@ def test_discover_arrow_function(self, js_support): """Test discovering arrow functions assigned to variables.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -const add = (a, b) => { +export const add = (a, b) => { return a + b; }; -const multiply = (x, y) => x * y; +export const multiply = (x, y) => x * y; """) f.flush() @@ -104,11 +104,11 @@ def test_discover_function_without_return_excluded(self, js_support): """Test that functions without return are excluded by default.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function withReturn() { +export function withReturn() { return 1; } -function withoutReturn() { +export function withoutReturn() { console.log("hello"); } """) @@ -124,7 +124,7 @@ def test_discover_class_methods(self, js_support): """Test discovering class methods.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -147,11 +147,11 @@ def test_discover_async_functions(self, js_support): """Test discovering async functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -async function fetchData(url) { +export async function fetchData(url) { return await fetch(url); } -function syncFunction() { +export function syncFunction() { return 1; } """) @@ -171,11 +171,11 @@ def test_discover_with_filter_exclude_async(self, js_support): """Test filtering out async functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -async function asyncFunc() { +export async function asyncFunc() { return 1; } -function syncFunc() { +export function syncFunc() { return 2; } """) @@ -191,11 +191,11 @@ def test_discover_with_filter_exclude_methods(self, js_support): """Test filtering out class methods.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function standalone() { +export function standalone() { return 1; } -class MyClass { +export class MyClass { method() { return 2; } @@ -212,11 +212,11 @@ class MyClass { def test_discover_line_numbers(self, js_support): """Test that line numbers are correctly captured.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""function func1() { + f.write("""export function func1() { return 1; } -function func2() { +export function func2() { const x = 1; const y = 2; return x + y; @@ -238,7 +238,7 @@ def test_discover_generator_function(self, js_support): """Test discovering generator functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -function* numberGenerator() { +export function* numberGenerator() { yield 1; yield 2; return 3; @@ -271,7 +271,7 @@ def test_discover_function_expression(self, js_support): """Test discovering function expressions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -const add = function(a, b) { +export const add = function(a, b) { return a + b; }; """) @@ -290,7 +290,7 @@ def test_discover_immediately_invoked_function_excluded(self, js_support): return 1; })(); -function named() { +export function named() { return 2; } """) @@ -476,7 +476,7 @@ class TestExtractCodeContext: def test_extract_simple_function(self, js_support): """Test extracting context for a simple function.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""function add(a, b) { + f.write("""export function add(a, b) { return a + b; } """) @@ -495,11 +495,11 @@ def test_extract_simple_function(self, js_support): def test_extract_with_helper(self, js_support): """Test extracting context with helper functions.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""function helper(x) { + f.write("""export function helper(x) { return x * 2; } -function main(a) { +export function main(a) { return helper(a) + 1; } """) @@ -523,7 +523,7 @@ class TestIntegration: def test_discover_and_replace_workflow(self, js_support): """Test full discover -> replace workflow.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - original_code = """function fibonacci(n) { + original_code = """export function fibonacci(n) { if (n <= 1) { return n; } @@ -541,7 +541,7 @@ def test_discover_and_replace_workflow(self, js_support): assert func.function_name == "fibonacci" # Replace - optimized_code = """function fibonacci(n) { + optimized_code = """export function fibonacci(n) { // Memoized version const memo = {0: 0, 1: 1}; for (let i = 2; i <= n; i++) { @@ -561,7 +561,7 @@ def test_multiple_classes_and_functions(self, js_support): """Test discovering and working with complex file.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -571,13 +571,13 @@ class Calculator { } } -class StringUtils { +export class StringUtils { reverse(s) { return s.split('').reverse().join(''); } } -function standalone() { +export function standalone() { return 42; } """) @@ -605,11 +605,11 @@ def test_jsx_file(self, js_support): f.write(""" import React from 'react'; -function Button({ onClick, children }) { +export function Button({ onClick, children }) { return ; } -const Card = ({ title, content }) => { +export const Card = ({ title, content }) => { return (

{title}

@@ -673,7 +673,7 @@ class TestClassMethodExtraction: def test_extract_class_method_wraps_in_class(self, js_support): """Test that extracting a class method wraps it in a class definition.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Calculator { + f.write("""export class Calculator { add(a, b) { return a + b; } @@ -694,6 +694,7 @@ def test_extract_class_method_wraps_in_class(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check for exact extraction output + # Note: export keyword is not included in extracted class wrapper expected_code = """class Calculator { add(a, b) { return a + b; @@ -709,7 +710,7 @@ def test_extract_class_method_with_jsdoc(self, js_support): f.write("""/** * A simple calculator class. */ -class Calculator { +export class Calculator { /** * Adds two numbers. * @param {number} a - First number @@ -730,10 +731,9 @@ class Calculator { context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check - includes class JSDoc, class definition, method JSDoc, and method - expected_code = """/** - * A simple calculator class. - */ -class Calculator { + # Note: export keyword is not included in extracted class wrapper + # Note: Class-level JSDoc is not included when extracting a method + expected_code = """class Calculator { /** * Adds two numbers. * @param {number} a - First number @@ -751,7 +751,7 @@ class Calculator { def test_extract_class_method_syntax_valid(self, js_support): """Test that extracted class method code is always syntactically valid.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class FibonacciCalculator { + f.write("""export class FibonacciCalculator { fibonacci(n) { if (n <= 1) { return n; @@ -769,6 +769,7 @@ def test_extract_class_method_syntax_valid(self, js_support): context = js_support.extract_code_context(fib_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class FibonacciCalculator { fibonacci(n) { if (n <= 1) { @@ -784,7 +785,7 @@ def test_extract_class_method_syntax_valid(self, js_support): def test_extract_nested_class_method(self, js_support): """Test extracting a method from a nested class structure.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Outer { + f.write("""export class Outer { createInner() { return class Inner { getValue() { @@ -808,6 +809,7 @@ def test_extract_nested_class_method(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class Outer { add(a, b) { return a + b; @@ -820,7 +822,7 @@ def test_extract_nested_class_method(self, js_support): def test_extract_async_class_method(self, js_support): """Test extracting an async class method.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class ApiClient { + f.write("""export class ApiClient { async fetchData(url) { const response = await fetch(url); return response.json(); @@ -836,6 +838,7 @@ def test_extract_async_class_method(self, js_support): context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class ApiClient { async fetchData(url) { const response = await fetch(url); @@ -849,7 +852,7 @@ def test_extract_async_class_method(self, js_support): def test_extract_static_class_method(self, js_support): """Test extracting a static class method.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class MathUtils { + f.write("""export class MathUtils { static add(a, b) { return a + b; } @@ -869,6 +872,7 @@ def test_extract_static_class_method(self, js_support): context = js_support.extract_code_context(add_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class MathUtils { static add(a, b) { return a + b; @@ -881,7 +885,7 @@ def test_extract_static_class_method(self, js_support): def test_extract_class_method_without_class_jsdoc(self, js_support): """Test extracting a method from a class without JSDoc.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class SimpleClass { + f.write("""export class SimpleClass { simpleMethod() { return "hello"; } @@ -896,6 +900,7 @@ def test_extract_class_method_without_class_jsdoc(self, js_support): context = js_support.extract_code_context(method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class SimpleClass { simpleMethod() { return "hello"; @@ -1061,7 +1066,7 @@ class TestClassMethodEdgeCases: def test_class_with_constructor(self, js_support): """Test handling classes with constructors.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Counter { + f.write("""export class Counter { constructor(start = 0) { this.value = start; } @@ -1083,7 +1088,7 @@ def test_class_with_constructor(self, js_support): def test_class_with_getters_setters(self, js_support): """Test handling classes with getters and setters.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Person { + f.write("""export class Person { constructor(name) { this._name = name; } @@ -1113,13 +1118,13 @@ def test_class_with_getters_setters(self, js_support): def test_class_extending_another(self, js_support): """Test handling classes that extend another class.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Animal { + f.write("""export class Animal { speak() { return 'sound'; } } -class Dog extends Animal { +export class Dog extends Animal { speak() { return 'bark'; } @@ -1141,6 +1146,7 @@ class Dog extends Animal { context = js_support.extract_code_context(fetch_method, file_path.parent, file_path.parent) # Full string equality check + # Note: export keyword is not included in extracted class wrapper expected_code = """class Dog { fetch() { return 'ball'; @@ -1153,7 +1159,7 @@ class Dog extends Animal { def test_class_with_private_method(self, js_support): """Test handling classes with private methods (ES2022+).""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class SecureClass { + f.write("""export class SecureClass { #privateMethod() { return 'secret'; } @@ -1175,7 +1181,7 @@ def test_class_with_private_method(self, js_support): def test_commonjs_class_export(self, js_support): """Test handling CommonJS exported classes.""" with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: - f.write("""class Calculator { + f.write("""export class Calculator { add(a, b) { return a + b; } @@ -1236,7 +1242,7 @@ def test_extract_context_then_replace_method(self, js_support): 3. Replace extracts just the method body and replaces in original """ original_source = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1303,7 +1309,7 @@ class Counter { # Verify result with exact string equality expected_result = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -1333,7 +1339,7 @@ def test_typescript_extract_context_then_replace_method(self): ts_support = TypeScriptSupport() original_source = """\ -class User { +export class User { private name: string; private age: number; @@ -1350,8 +1356,6 @@ class User { return this.age; } } - -export { User }; """ with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: f.write(original_source) @@ -1408,7 +1412,7 @@ class User { # Verify result with exact string equality expected_result = """\ -class User { +export class User { private name: string; private age: number; @@ -1426,8 +1430,6 @@ class User { return this.age; } } - -export { User }; """ assert result == expected_result, ( f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}" @@ -1437,7 +1439,7 @@ class User { def test_extract_replace_preserves_other_methods(self, js_support): """Test that replacing one method doesn't affect others.""" original_source = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -1499,7 +1501,7 @@ class Calculator { # Verify result with exact string equality expected_result = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -1525,7 +1527,7 @@ class Calculator { def test_extract_static_method_then_replace(self, js_support): """Test extracting and replacing a static method.""" original_source = """\ -class MathUtils { +export class MathUtils { constructor() { this.cache = {}; } @@ -1538,8 +1540,6 @@ class MathUtils { return a * b; } } - -module.exports = { MathUtils }; """ with tempfile.NamedTemporaryFile(suffix=".js", mode="w", delete=False) as f: f.write(original_source) @@ -1586,7 +1586,7 @@ class MathUtils { # Verify result with exact string equality expected_result = """\ -class MathUtils { +export class MathUtils { constructor() { this.cache = {}; } @@ -1600,8 +1600,6 @@ class MathUtils { return a * b; } } - -module.exports = { MathUtils }; """ assert result == expected_result, ( f"Replacement result does not match expected.\nExpected:\n{expected_result}\n\nGot:\n{result}" diff --git a/tests/test_languages/test_javascript_test_discovery.py b/tests/test_languages/test_javascript_test_discovery.py index 473bd330e..df697d482 100644 --- a/tests/test_languages/test_javascript_test_discovery.py +++ b/tests/test_languages/test_javascript_test_discovery.py @@ -29,7 +29,7 @@ def test_discover_tests_basic(self, js_support): # Create source file source_file = tmpdir / "math.js" source_file.write_text(""" -function add(a, b) { +export function add(a, b) { return a + b; } @@ -71,7 +71,7 @@ def test_discover_tests_spec_suffix(self, js_support): # Create source file source_file = tmpdir / "calculator.js" source_file.write_text(""" -function multiply(a, b) { +export function multiply(a, b) { return a * b; } @@ -103,7 +103,7 @@ def test_discover_tests_in_tests_directory(self, js_support): # Create source file source_file = tmpdir / "utils.js" source_file.write_text(""" -function formatDate(date) { +export function formatDate(date) { return date.toISOString(); } @@ -136,11 +136,11 @@ def test_discover_tests_nested_describe(self, js_support): source_file = tmpdir / "string_utils.js" source_file.write_text(""" -function capitalize(str) { +export function capitalize(str) { return str.charAt(0).toUpperCase() + str.slice(1); } -function lowercase(str) { +export function lowercase(str) { return str.toLowerCase(); } @@ -186,7 +186,7 @@ def test_discover_tests_with_it_block(self, js_support): source_file = tmpdir / "array_utils.js" source_file.write_text(""" -function sum(arr) { +export function sum(arr) { return arr.reduce((a, b) => a + b, 0); } @@ -254,7 +254,7 @@ def test_discover_tests_default_export(self, js_support): source_file = tmpdir / "greeter.js" source_file.write_text(""" -function greet(name) { +export function greet(name) { return `Hello, ${name}!`; } @@ -282,7 +282,7 @@ def test_discover_tests_class_methods(self, js_support): source_file = tmpdir / "calculator_class.js" source_file.write_text(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -333,7 +333,7 @@ def test_discover_tests_multi_level_directories(self, js_support): source_file = src_dir / "helpers.js" source_file.write_text(""" -function clamp(value, min, max) { +export function clamp(value, min, max) { return Math.min(Math.max(value, min), max); } @@ -375,11 +375,11 @@ def test_discover_tests_async_functions(self, js_support): source_file = tmpdir / "async_utils.js" source_file.write_text(""" -async function fetchData(url) { +export async function fetchData(url) { return await fetch(url).then(r => r.json()); } -async function delay(ms) { +export async function delay(ms) { return new Promise(resolve => setTimeout(resolve, ms)); } @@ -413,7 +413,7 @@ def test_discover_tests_jsx_component(self, js_support): source_file.write_text(""" import React from 'react'; -function Button({ onClick, children }) { +export function Button({ onClick, children }) { return ; } @@ -449,7 +449,7 @@ def test_discover_tests_no_matching_tests(self, js_support): source_file = tmpdir / "untested.js" source_file.write_text(""" -function untestedFunction() { +export function untestedFunction() { return 42; } @@ -479,11 +479,11 @@ def test_discover_tests_function_name_in_source(self, js_support): source_file = tmpdir / "validators.js" source_file.write_text(""" -function isEmail(str) { +export function isEmail(str) { return str.includes('@'); } -function isUrl(str) { +export function isUrl(str) { return str.startsWith('http'); } @@ -515,11 +515,11 @@ def test_discover_tests_multiple_test_files(self, js_support): source_file = tmpdir / "shared_utils.js" source_file.write_text(""" -function helper1() { +export function helper1() { return 1; } -function helper2() { +export function helper2() { return 2; } @@ -558,7 +558,7 @@ def test_discover_tests_template_literal_names(self, js_support): source_file = tmpdir / "format.js" source_file.write_text(""" -function formatNumber(n) { +export function formatNumber(n) { return n.toFixed(2); } @@ -587,7 +587,7 @@ def test_discover_tests_aliased_import(self, js_support): source_file = tmpdir / "transform.js" source_file.write_text(""" -function transformData(data) { +export function transformData(data) { return data.map(x => x * 2); } @@ -792,8 +792,8 @@ def test_require_named_import(self, js_support): source_file = tmpdir / "funcs.js" source_file.write_text(""" -function funcA() { return 1; } -function funcB() { return 2; } +export function funcA() { return 1; } +export function funcB() { return 2; } module.exports = { funcA, funcB }; """) @@ -846,7 +846,7 @@ def test_default_import(self, js_support): source_file = tmpdir / "default_export.js" source_file.write_text(""" -function mainFunc() { return 'main'; } +export function mainFunc() { return 'main'; } module.exports = mainFunc; """) @@ -875,7 +875,7 @@ def test_comments_in_test_file(self, js_support): source_file = tmpdir / "commented.js" source_file.write_text(""" -function compute() { return 42; } +export function compute() { return 42; } module.exports = { compute }; """) @@ -908,7 +908,7 @@ def test_test_file_with_syntax_error(self, js_support): source_file = tmpdir / "valid.js" source_file.write_text(""" -function validFunc() { return 1; } +export function validFunc() { return 1; } module.exports = { validFunc }; """) @@ -933,8 +933,8 @@ def test_function_with_same_name_as_jest_api(self, js_support): source_file = tmpdir / "conflict.js" source_file.write_text(""" -function test(value) { return value > 0; } -function describe(obj) { return JSON.stringify(obj); } +export function test(value) { return value > 0; } +export function describe(obj) { return JSON.stringify(obj); } module.exports = { test, describe }; """) @@ -962,7 +962,7 @@ def test_empty_test_directory(self, js_support): source_file = tmpdir / "lonely.js" source_file.write_text(""" -function lonelyFunc() { return 'alone'; } +export function lonelyFunc() { return 'alone'; } module.exports = { lonelyFunc }; """) @@ -980,14 +980,14 @@ def test_circular_imports(self, js_support): file_a = tmpdir / "moduleA.js" file_a.write_text(""" const { funcB } = require('./moduleB'); -function funcA() { return 'A' + (funcB ? funcB() : ''); } +export function funcA() { return 'A' + (funcB ? funcB() : ''); } module.exports = { funcA }; """) file_b = tmpdir / "moduleB.js" file_b.write_text(""" const { funcA } = require('./moduleA'); -function funcB() { return 'B'; } +export function funcB() { return 'B'; } module.exports = { funcB }; """) @@ -1126,17 +1126,17 @@ def test_full_discovery_workflow(self, js_support): # Source file source_file = src_dir / "utils.js" source_file.write_text(r""" -function validateEmail(email) { +export function validateEmail(email) { const re = /^[^\s@]+@[^\s@]+\.[^\s@]+$/; return re.test(email); } -function validatePhone(phone) { +export function validatePhone(phone) { const re = /^\d{10}$/; return re.test(phone); } -function formatName(first, last) { +export function formatName(first, last) { return `${first} ${last}`.trim(); } @@ -1197,7 +1197,7 @@ def test_discovery_with_fixtures(self, js_support): source_file = tmpdir / "database.js" source_file.write_text(""" -class Database { +export class Database { constructor() { this.data = []; } @@ -1259,13 +1259,13 @@ def test_test_file_imports_different_module(self, js_support): # Create two source files source_a = tmpdir / "moduleA.js" source_a.write_text(""" -function funcA() { return 'A'; } +export function funcA() { return 'A'; } module.exports = { funcA }; """) source_b = tmpdir / "moduleB.js" source_b.write_text(""" -function funcB() { return 'B'; } +export function funcB() { return 'B'; } module.exports = { funcB }; """) @@ -1296,9 +1296,9 @@ def test_test_file_imports_only_specific_function(self, js_support): source_file = tmpdir / "utils.js" source_file.write_text(""" -function funcOne() { return 1; } -function funcTwo() { return 2; } -function funcThree() { return 3; } +export function funcOne() { return 1; } +export function funcTwo() { return 2; } +export function funcThree() { return 3; } module.exports = { funcOne, funcTwo, funcThree }; """) @@ -1325,7 +1325,7 @@ def test_function_name_as_string_not_import(self, js_support): source_file = tmpdir / "target.js" source_file.write_text(""" -function targetFunc() { return 'target'; } +export function targetFunc() { return 'target'; } module.exports = { targetFunc }; """) @@ -1354,7 +1354,7 @@ def test_module_import_with_method_access(self, js_support): source_file = tmpdir / "math.js" source_file.write_text(""" -function calculate(x) { return x * 2; } +export function calculate(x) { return x * 2; } module.exports = { calculate }; """) @@ -1380,7 +1380,7 @@ def test_class_method_discovery_via_class_import(self, js_support): source_file = tmpdir / "myclass.js" source_file.write_text(""" -class MyClass { +export class MyClass { methodA() { return 'A'; } methodB() { return 'B'; } } @@ -1416,7 +1416,7 @@ def test_nested_module_structure(self, js_support): source_file = src_dir / "helpers.js" source_file.write_text(""" -function deepHelper() { return 'deep'; } +export function deepHelper() { return 'deep'; } module.exports = { deepHelper }; """) @@ -1574,9 +1574,9 @@ def test_multiple_functions_same_file_different_tests(self, js_support): source_file = tmpdir / "multiple.js" source_file.write_text(""" -function addNumbers(a, b) { return a + b; } -function subtractNumbers(a, b) { return a - b; } -function multiplyNumbers(a, b) { return a * b; } +export function addNumbers(a, b) { return a + b; } +export function subtractNumbers(a, b) { return a - b; } +export function multiplyNumbers(a, b) { return a * b; } module.exports = { addNumbers, subtractNumbers, multiplyNumbers }; """) @@ -1613,7 +1613,7 @@ def test_test_in_wrong_describe_still_discovered(self, js_support): source_file = tmpdir / "funcs.js" source_file.write_text(""" -function targetFunc() { return 'target'; } +export function targetFunc() { return 'target'; } module.exports = { targetFunc }; """) @@ -1705,7 +1705,7 @@ def test_class_method_qualified_name(self, js_support): source_file = tmpdir / "calculator.js" source_file.write_text(""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } subtract(a, b) { return a - b; } } @@ -1726,7 +1726,7 @@ def test_nested_class_method(self, js_support): source_file = tmpdir / "nested.js" source_file.write_text(""" -class Outer { +export class Outer { innerMethod() { class Inner { deepMethod() { return 'deep'; } diff --git a/tests/test_languages/test_javascript_test_runner.py b/tests/test_languages/test_javascript_test_runner.py index 87e712038..905ef24a8 100644 --- a/tests/test_languages/test_javascript_test_runner.py +++ b/tests/test_languages/test_javascript_test_runner.py @@ -668,10 +668,10 @@ def test_create_codeflash_jest_config(self): assert result_path.exists() assert result_path.name == "jest.codeflash.config.js" - # Verify it contains the tsconfig reference + # Verify it contains ESM package transformation patterns content = result_path.read_text() - assert "tsconfig.codeflash.json" in content - assert "ts-jest" in content + assert "transformIgnorePatterns" in content + assert "node_modules" in content def test_get_jest_config_for_project_with_bundler(self): """Test that bundler projects get codeflash Jest config.""" diff --git a/tests/test_languages/test_js_code_extractor.py b/tests/test_languages/test_js_code_extractor.py index b1dcee81f..a21f15e2e 100644 --- a/tests/test_languages/test_js_code_extractor.py +++ b/tests/test_languages/test_js_code_extractor.py @@ -109,12 +109,7 @@ def test_extract_context_includes_direct_helpers(self, js_support, cjs_project): factorial_helper = helper_dict["factorial"] expected_factorial_code = """\ -/** - * Calculate factorial recursively. - * @param n - Non-negative integer - * @returns Factorial of n - */ -function factorial(n) { +export function factorial(n) { // Intentionally inefficient recursive implementation if (n <= 1) return 1; return n * factorial(n - 1); @@ -196,46 +191,22 @@ def test_extract_compound_interest_helpers(self, js_support, cjs_project): # STRICT: Verify each helper's code exactly expected_add_code = """\ -/** - * Add two numbers. - * @param a - First number - * @param b - Second number - * @returns Sum of a and b - */ -function add(a, b) { +export function add(a, b) { return a + b; }""" expected_multiply_code = """\ -/** - * Multiply two numbers. - * @param a - First number - * @param b - Second number - * @returns Product of a and b - */ -function multiply(a, b) { +export function multiply(a, b) { return a * b; }""" expected_format_number_code = """\ -/** - * Format a number to specified decimal places. - * @param num - Number to format - * @param decimals - Number of decimal places - * @returns Formatted number - */ -function formatNumber(num, decimals) { +export function formatNumber(num, decimals) { return Number(num.toFixed(decimals)); }""" expected_validate_input_code = """\ -/** - * Validate that input is a valid number. - * @param value - Value to validate - * @param name - Parameter name for error message - * @throws Error if value is not a valid number - */ -function validateInput(value, name) { +export function validateInput(value, name) { if (typeof value !== 'number' || isNaN(value)) { throw new Error(`Invalid ${name}: must be a number`); } @@ -317,13 +288,7 @@ class Calculator { assert set(helper_dict.keys()) == {"add"}, f"Expected 'add' helper, got: {list(helper_dict.keys())}" expected_add_code = """\ -/** - * Add two numbers. - * @param a - First number - * @param b - Second number - * @returns Sum of a and b - */ -function add(a, b) { +export function add(a, b) { return a + b; }""" @@ -702,7 +667,7 @@ def js_support(self): def test_standalone_function(self, js_support, tmp_path): """Test standalone function with no helpers.""" source = """\ -function standalone(x) { +export function standalone(x) { return x * 2; } @@ -718,7 +683,7 @@ def test_standalone_function(self, js_support, tmp_path): # STRICT: Exact code comparison expected_code = """\ -function standalone(x) { +export function standalone(x) { return x * 2; }""" assert context.target_code.strip() == expected_code.strip(), ( @@ -735,7 +700,7 @@ def test_external_package_excluded(self, js_support, tmp_path): source = """\ const _ = require('lodash'); -function processArray(arr) { +export function processArray(arr) { return _.map(arr, x => x * 2); } @@ -750,7 +715,7 @@ def test_external_package_excluded(self, js_support, tmp_path): context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path) expected_code = """\ -function processArray(arr) { +export function processArray(arr) { return _.map(arr, x => x * 2); }""" @@ -769,7 +734,7 @@ def test_external_package_excluded(self, js_support, tmp_path): def test_recursive_function(self, js_support, tmp_path): """Test recursive function doesn't list itself as helper.""" source = """\ -function fibonacci(n) { +export function fibonacci(n) { if (n <= 1) return n; return fibonacci(n - 1) + fibonacci(n - 2); } @@ -786,7 +751,7 @@ def test_recursive_function(self, js_support, tmp_path): # STRICT: Exact code comparison expected_code = """\ -function fibonacci(n) { +export function fibonacci(n) { if (n <= 1) return n; return fibonacci(n - 1) + fibonacci(n - 2); }""" @@ -803,7 +768,7 @@ def test_arrow_function_helper(self, js_support, tmp_path): source = """\ const helper = (x) => x * 2; -const processValue = (value) => { +export const processValue = (value) => { return helper(value) + 1; }; @@ -818,7 +783,7 @@ def test_arrow_function_helper(self, js_support, tmp_path): context = js_support.extract_code_context(function=func, project_root=tmp_path, module_root=tmp_path) expected_code = """\ -const processValue = (value) => { +export const processValue = (value) => { return helper(value) + 1; };""" @@ -854,7 +819,7 @@ def ts_support(self): def test_method_extraction_includes_constructor(self, js_support, tmp_path): """Test that extracting a class method includes the constructor.""" source = """\ -class Counter { +export class Counter { constructor(initial = 0) { this.count = initial; } @@ -894,7 +859,7 @@ class Counter { def test_method_extraction_class_without_constructor(self, js_support, tmp_path): """Test extracting a method from a class that has no constructor.""" source = """\ -class MathUtils { +export class MathUtils { add(a, b) { return a + b; } @@ -928,7 +893,7 @@ class MathUtils { def test_typescript_method_extraction_includes_fields(self, ts_support, tmp_path): """Test that TypeScript method extraction includes class fields.""" source = """\ -class User { +export class User { private name: string; public age: number; @@ -941,8 +906,6 @@ class User { return this.name; } } - -export { User }; """ test_file = tmp_path / "user.ts" test_file.write_text(source) @@ -974,7 +937,7 @@ class User { def test_typescript_fields_only_no_constructor(self, ts_support, tmp_path): """Test TypeScript class with fields but no constructor.""" source = """\ -class Config { +export class Config { readonly apiUrl: string = "https://api.example.com"; timeout: number = 5000; @@ -982,8 +945,6 @@ class Config { return this.apiUrl; } } - -export { Config }; """ test_file = tmp_path / "config.ts" test_file.write_text(source) @@ -1010,7 +971,7 @@ class Config { def test_constructor_with_jsdoc(self, js_support, tmp_path): """Test that constructor with JSDoc is fully extracted.""" source = """\ -class Logger { +export class Logger { /** * Create a new Logger instance. * @param {string} prefix - The prefix to use for log messages. @@ -1056,7 +1017,7 @@ class Logger { def test_static_method_includes_constructor(self, js_support, tmp_path): """Test that static method extraction also includes constructor context.""" source = """\ -class Factory { +export class Factory { constructor(config) { this.config = config; } @@ -1212,13 +1173,11 @@ def test_extract_same_file_interface_from_parameter(self, ts_support, tmp_path): y: number; } -function distance(p1: Point, p2: Point): number { +export function distance(p1: Point, p2: Point): number { const dx = p2.x - p1.x; const dy = p2.y - p1.y; return Math.sqrt(dx * dx + dy * dy); } - -export { distance }; """ test_file = tmp_path / "geometry.ts" test_file.write_text(source) @@ -1251,7 +1210,7 @@ def test_extract_same_file_enum_from_parameter(self, ts_support, tmp_path): FAILURE = 'failure', } -function processStatus(status: Status): string { +export function processStatus(status: Status): string { switch (status) { case Status.PENDING: return 'Processing...'; @@ -1261,8 +1220,6 @@ def test_extract_same_file_enum_from_parameter(self, ts_support, tmp_path): return 'Failed!'; } } - -export { processStatus }; """ test_file = tmp_path / "status.ts" test_file.write_text(source) @@ -1295,11 +1252,9 @@ def test_extract_same_file_type_alias_from_return_type(self, ts_support, tmp_pat success: boolean; }; -function compute(x: number): Result { +export function compute(x: number): Result { return { value: x * 2, success: true }; } - -export { compute }; """ test_file = tmp_path / "compute.ts" test_file.write_text(source) @@ -1331,7 +1286,7 @@ def test_extract_class_field_types(self, ts_support, tmp_path): retries: number; } -class Service { +export class Service { private config: Config; constructor(config: Config) { @@ -1342,8 +1297,6 @@ class Service { return this.config.timeout; } } - -export { Service }; """ test_file = tmp_path / "service.ts" test_file.write_text(source) @@ -1372,11 +1325,9 @@ class Service { def test_primitive_types_not_included(self, ts_support, tmp_path): """Test that primitive types (number, string, etc.) are not extracted.""" source = """\ -function add(a: number, b: number): number { +export function add(a: number, b: number): number { return a + b; } - -export { add }; """ test_file = tmp_path / "add.ts" test_file.write_text(source) @@ -1405,11 +1356,9 @@ def test_extract_multiple_types(self, ts_support, tmp_path): height: number; } -function createRect(origin: Point, size: Size): { origin: Point; size: Size } { +export function createRect(origin: Point, size: Size): { origin: Point; size: Size } { return { origin, size }; } - -export { createRect }; """ test_file = tmp_path / "rect.ts" test_file.write_text(source) @@ -1447,7 +1396,7 @@ def test_extract_imported_type_definition(self, ts_support, ts_types_project): geometry_file.write_text("""\ import { Point, CalculationConfig } from './types'; -function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): number { +export function calculateDistance(p1: Point, p2: Point, config: CalculationConfig): number { const dx = p2.x - p1.x; const dy = p2.y - p1.y; const distance = Math.sqrt(dx * dx + dy * dy); @@ -1458,8 +1407,6 @@ def test_extract_imported_type_definition(self, ts_support, ts_types_project): } return distance; } - -export { calculateDistance }; """) functions = ts_support.discover_functions(geometry_file) @@ -1506,11 +1453,9 @@ def test_type_with_jsdoc_included(self, ts_support, tmp_path): name: string; } -function greetUser(user: User): string { +export function greetUser(user: User): string { return `Hello, ${user.name}!`; } - -export { greetUser }; """ test_file = tmp_path / "user.ts" test_file.write_text(source) diff --git a/tests/test_languages/test_js_code_replacer.py b/tests/test_languages/test_js_code_replacer.py index c5b2cc001..9e251804a 100644 --- a/tests/test_languages/test_js_code_replacer.py +++ b/tests/test_languages/test_js_code_replacer.py @@ -757,7 +757,7 @@ class TestSimpleFunctionReplacement: def test_replace_simple_function_body(self, js_support, temp_project): """Test replacing a simple function body preserves structure exactly.""" original_source = """\ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -769,7 +769,7 @@ def test_replace_simple_function_body(self, js_support, temp_project): # Optimized version with different body optimized_code = """\ -function add(a, b) { +export function add(a, b) { // Optimized: direct return return a + b; } @@ -778,7 +778,7 @@ def test_replace_simple_function_body(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function add(a, b) { +export function add(a, b) { // Optimized: direct return return a + b; } @@ -789,7 +789,7 @@ def test_replace_simple_function_body(self, js_support, temp_project): def test_replace_function_with_multiple_statements(self, js_support, temp_project): """Test replacing function with complex multi-statement body.""" original_source = """\ -function processData(data) { +export function processData(data) { const result = []; for (let i = 0; i < data.length; i++) { result.push(data[i] * 2); @@ -805,7 +805,7 @@ def test_replace_function_with_multiple_statements(self, js_support, temp_projec # Optimized version using map optimized_code = """\ -function processData(data) { +export function processData(data) { return data.map(x => x * 2); } """ @@ -813,7 +813,7 @@ def test_replace_function_with_multiple_statements(self, js_support, temp_projec result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function processData(data) { +export function processData(data) { return data.map(x => x * 2); } """ @@ -825,12 +825,12 @@ def test_replace_preserves_surrounding_code(self, js_support, temp_project): original_source = """\ const CONFIG = { debug: true }; -function targetFunction(x) { +export function targetFunction(x) { console.log(x); return x * 2; } -function otherFunction(y) { +export function otherFunction(y) { return y + 1; } @@ -843,7 +843,7 @@ def test_replace_preserves_surrounding_code(self, js_support, temp_project): target_func = next(f for f in functions if f.function_name == "targetFunction") optimized_code = """\ -function targetFunction(x) { +export function targetFunction(x) { return x << 1; } """ @@ -853,11 +853,11 @@ def test_replace_preserves_surrounding_code(self, js_support, temp_project): expected_result = """\ const CONFIG = { debug: true }; -function targetFunction(x) { +export function targetFunction(x) { return x << 1; } -function otherFunction(y) { +export function otherFunction(y) { return y + 1; } @@ -873,7 +873,7 @@ class TestClassMethodReplacement: def test_replace_class_method_body(self, js_support, temp_project): """Test replacing a class method body preserves class structure.""" original_source = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -896,7 +896,7 @@ class Calculator { # Optimized version provided in class context optimized_code = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -910,7 +910,7 @@ class Calculator { result = js_support.replace_function(original_source, add_method, optimized_code) expected_result = """\ -class Calculator { +export class Calculator { constructor(precision = 2) { this.precision = precision; } @@ -930,7 +930,7 @@ class Calculator { def test_replace_method_calling_sibling_methods(self, js_support, temp_project): """Test replacing method that calls other methods in same class.""" original_source = """\ -class DataProcessor { +export class DataProcessor { constructor() { this.cache = new Map(); } @@ -958,7 +958,7 @@ class DataProcessor { process_method = next(f for f in functions if f.function_name == "process") optimized_code = """\ -class DataProcessor { +export class DataProcessor { constructor() { this.cache = new Map(); } @@ -975,7 +975,7 @@ class DataProcessor { result = js_support.replace_function(original_source, process_method, optimized_code) expected_result = """\ -class DataProcessor { +export class DataProcessor { constructor() { this.cache = new Map(); } @@ -1008,7 +1008,7 @@ def test_replace_preserves_jsdoc_above_function(self, js_support, temp_project): * @param {number} b - Second number * @returns {number} The sum */ -function add(a, b) { +export function add(a, b) { const sum = a + b; return sum; } @@ -1020,13 +1020,7 @@ def test_replace_preserves_jsdoc_above_function(self, js_support, temp_project): func = functions[0] optimized_code = """\ -/** - * Calculates the sum of two numbers. - * @param {number} a - First number - * @param {number} b - Second number - * @returns {number} The sum - */ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -1040,7 +1034,7 @@ def test_replace_preserves_jsdoc_above_function(self, js_support, temp_project): * @param {number} b - Second number * @returns {number} The sum */ -function add(a, b) { +export function add(a, b) { return a + b; } """ @@ -1054,7 +1048,7 @@ def test_replace_class_method_with_jsdoc(self, js_support, temp_project): * A simple cache implementation. * @class Cache */ -class Cache { +export class Cache { constructor() { this.data = new Map(); } @@ -1103,7 +1097,7 @@ class Cache { * A simple cache implementation. * @class Cache */ -class Cache { +export class Cache { constructor() { this.data = new Map(); } @@ -1128,7 +1122,7 @@ class TestAsyncFunctionReplacement: def test_replace_async_function_body(self, js_support, temp_project): """Test replacing async function preserves async keyword.""" original_source = """\ -async function fetchData(url) { +export async function fetchData(url) { const response = await fetch(url); const data = await response.json(); return data; @@ -1141,7 +1135,7 @@ def test_replace_async_function_body(self, js_support, temp_project): func = functions[0] optimized_code = """\ -async function fetchData(url) { +export async function fetchData(url) { return (await fetch(url)).json(); } """ @@ -1149,7 +1143,7 @@ def test_replace_async_function_body(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -async function fetchData(url) { +export async function fetchData(url) { return (await fetch(url)).json(); } """ @@ -1159,7 +1153,7 @@ def test_replace_async_function_body(self, js_support, temp_project): def test_replace_async_class_method(self, js_support, temp_project): """Test replacing async class method.""" original_source = """\ -class ApiClient { +export class ApiClient { constructor(baseUrl) { this.baseUrl = baseUrl; } @@ -1198,7 +1192,7 @@ class ApiClient { result = js_support.replace_function(original_source, get_method, optimized_code) expected_result = """\ -class ApiClient { +export class ApiClient { constructor(baseUrl) { this.baseUrl = baseUrl; } @@ -1220,7 +1214,7 @@ class TestGeneratorFunctionReplacement: def test_replace_generator_function_body(self, js_support, temp_project): """Test replacing generator function preserves generator syntax.""" original_source = """\ -function* range(start, end) { +export function* range(start, end) { for (let i = start; i < end; i++) { yield i; } @@ -1233,7 +1227,7 @@ def test_replace_generator_function_body(self, js_support, temp_project): func = functions[0] optimized_code = """\ -function* range(start, end) { +export function* range(start, end) { let i = start; while (i < end) yield i++; } @@ -1242,7 +1236,7 @@ def test_replace_generator_function_body(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function* range(start, end) { +export function* range(start, end) { let i = start; while (i < end) yield i++; } @@ -1257,7 +1251,7 @@ class TestTypeScriptReplacement: def test_replace_typescript_function_with_types(self, ts_support, temp_project): """Test replacing TypeScript function preserves type annotations.""" original_source = """\ -function processArray(items: number[]): number { +export function processArray(items: number[]): number { let sum = 0; for (let i = 0; i < items.length; i++) { sum += items[i]; @@ -1272,7 +1266,7 @@ def test_replace_typescript_function_with_types(self, ts_support, temp_project): func = functions[0] optimized_code = """\ -function processArray(items: number[]): number { +export function processArray(items: number[]): number { return items.reduce((a, b) => a + b, 0); } """ @@ -1280,7 +1274,7 @@ def test_replace_typescript_function_with_types(self, ts_support, temp_project): result = ts_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function processArray(items: number[]): number { +export function processArray(items: number[]): number { return items.reduce((a, b) => a + b, 0); } """ @@ -1290,7 +1284,7 @@ def test_replace_typescript_function_with_types(self, ts_support, temp_project): def test_replace_typescript_class_method_with_generics(self, ts_support, temp_project): """Test replacing TypeScript generic class method.""" original_source = """\ -class Container { +export class Container { private items: T[] = []; add(item: T): void { @@ -1325,7 +1319,7 @@ class Container { result = ts_support.replace_function(original_source, get_all_method, optimized_code) expected_result = """\ -class Container { +export class Container { private items: T[] = []; add(item: T): void { @@ -1349,7 +1343,7 @@ def test_replace_typescript_interface_typed_function(self, ts_support, temp_proj email: string; } -function createUser(name: string, email: string): User { +export function createUser(name: string, email: string): User { const id = Math.random().toString(36).substring(2, 15); const user: User = { id: id, @@ -1366,7 +1360,7 @@ def test_replace_typescript_interface_typed_function(self, ts_support, temp_proj func = next(f for f in functions if f.function_name == "createUser") optimized_code = """\ -function createUser(name: string, email: string): User { +export function createUser(name: string, email: string): User { return { id: Math.random().toString(36).substring(2, 15), name, @@ -1384,7 +1378,7 @@ def test_replace_typescript_interface_typed_function(self, ts_support, temp_proj email: string; } -function createUser(name: string, email: string): User { +export function createUser(name: string, email: string): User { return { id: Math.random().toString(36).substring(2, 15), name, @@ -1402,7 +1396,7 @@ class TestComplexReplacements: def test_replace_function_with_nested_functions(self, js_support, temp_project): """Test replacing function that contains nested function definitions.""" original_source = """\ -function processItems(items) { +export function processItems(items) { function helper(item) { return item * 2; } @@ -1421,7 +1415,7 @@ def test_replace_function_with_nested_functions(self, js_support, temp_project): process_func = next(f for f in functions if f.function_name == "processItems") optimized_code = """\ -function processItems(items) { +export function processItems(items) { const helper = x => x * 2; return items.map(helper); } @@ -1430,7 +1424,7 @@ def test_replace_function_with_nested_functions(self, js_support, temp_project): result = js_support.replace_function(original_source, process_func, optimized_code) expected_result = """\ -function processItems(items) { +export function processItems(items) { const helper = x => x * 2; return items.map(helper); } @@ -1441,7 +1435,7 @@ def test_replace_function_with_nested_functions(self, js_support, temp_project): def test_replace_multiple_methods_sequentially(self, js_support, temp_project): """Test replacing multiple methods in the same class sequentially.""" original_source = """\ -class MathUtils { +export class MathUtils { static sum(arr) { let total = 0; for (let i = 0; i < arr.length; i++) { @@ -1478,7 +1472,7 @@ class MathUtils { result = js_support.replace_function(original_source, sum_method, optimized_sum) expected_after_first = """\ -class MathUtils { +export class MathUtils { static sum(arr) { return arr.reduce((a, b) => a + b, 0); } @@ -1499,7 +1493,7 @@ class MathUtils { def test_replace_function_with_complex_destructuring(self, js_support, temp_project): """Test replacing function with complex parameter destructuring.""" original_source = """\ -function processConfig({ server: { host, port }, database: { url, poolSize } }) { +export function processConfig({ server: { host, port }, database: { url, poolSize } }) { const serverUrl = host + ':' + port; const dbConnection = url + '?poolSize=' + poolSize; return { @@ -1515,7 +1509,7 @@ def test_replace_function_with_complex_destructuring(self, js_support, temp_proj func = functions[0] optimized_code = """\ -function processConfig({ server: { host, port }, database: { url, poolSize } }) { +export function processConfig({ server: { host, port }, database: { url, poolSize } }) { return { server: `${host}:${port}`, db: `${url}?poolSize=${poolSize}` @@ -1526,7 +1520,7 @@ def test_replace_function_with_complex_destructuring(self, js_support, temp_proj result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function processConfig({ server: { host, port }, database: { url, poolSize } }) { +export function processConfig({ server: { host, port }, database: { url, poolSize } }) { return { server: `${host}:${port}`, db: `${url}?poolSize=${poolSize}` @@ -1543,7 +1537,7 @@ class TestEdgeCases: def test_replace_minimal_function_body(self, js_support, temp_project): """Test replacing function with minimal body.""" original_source = """\ -function minimal() { +export function minimal() { return null; } """ @@ -1554,7 +1548,7 @@ def test_replace_minimal_function_body(self, js_support, temp_project): func = functions[0] optimized_code = """\ -function minimal() { +export function minimal() { return { initialized: true, timestamp: Date.now() }; } """ @@ -1562,7 +1556,7 @@ def test_replace_minimal_function_body(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function minimal() { +export function minimal() { return { initialized: true, timestamp: Date.now() }; } """ @@ -1572,7 +1566,7 @@ def test_replace_minimal_function_body(self, js_support, temp_project): def test_replace_single_line_function(self, js_support, temp_project): """Test replacing single-line function.""" original_source = """\ -function identity(x) { return x; } +export function identity(x) { return x; } """ file_path = temp_project / "utils.js" file_path.write_text(original_source, encoding="utf-8") @@ -1581,13 +1575,13 @@ def test_replace_single_line_function(self, js_support, temp_project): func = functions[0] optimized_code = """\ -function identity(x) { return x ?? null; } +export function identity(x) { return x ?? null; } """ result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function identity(x) { return x ?? null; } +export function identity(x) { return x ?? null; } """ assert result == expected_result assert js_support.validate_syntax(result) is True @@ -1595,7 +1589,7 @@ def test_replace_single_line_function(self, js_support, temp_project): def test_replace_function_with_special_characters_in_strings(self, js_support, temp_project): """Test replacing function containing special characters in strings.""" original_source = """\ -function formatMessage(name) { +export function formatMessage(name) { const greeting = 'Hello, ' + name + '!'; const special = "Contains \\"quotes\\" and \\n newlines"; return greeting + ' ' + special; @@ -1608,7 +1602,7 @@ def test_replace_function_with_special_characters_in_strings(self, js_support, t func = functions[0] optimized_code = """\ -function formatMessage(name) { +export function formatMessage(name) { return `Hello, ${name}! Contains "quotes" and newlines`; } @@ -1617,7 +1611,7 @@ def test_replace_function_with_special_characters_in_strings(self, js_support, t result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function formatMessage(name) { +export function formatMessage(name) { return `Hello, ${name}! Contains "quotes" and newlines`; } @@ -1628,7 +1622,7 @@ def test_replace_function_with_special_characters_in_strings(self, js_support, t def test_replace_function_with_regex(self, js_support, temp_project): """Test replacing function containing regex patterns.""" original_source = """\ -function validateEmail(email) { +export function validateEmail(email) { const pattern = /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/; if (pattern.test(email)) { return true; @@ -1643,7 +1637,7 @@ def test_replace_function_with_regex(self, js_support, temp_project): func = functions[0] optimized_code = """\ -function validateEmail(email) { +export function validateEmail(email) { return /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/.test(email); } """ @@ -1651,7 +1645,7 @@ def test_replace_function_with_regex(self, js_support, temp_project): result = js_support.replace_function(original_source, func, optimized_code) expected_result = """\ -function validateEmail(email) { +export function validateEmail(email) { return /^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$/.test(email); } """ @@ -1665,11 +1659,11 @@ class TestModuleExportHandling: def test_replace_exported_function_commonjs(self, js_support, temp_project): """Test replacing function in CommonJS module preserves exports.""" original_source = """\ -function helper(x) { +export function helper(x) { return x * 2; } -function main(data) { +export function main(data) { const results = []; for (let i = 0; i < data.length; i++) { results.push(helper(data[i])); @@ -1686,7 +1680,7 @@ def test_replace_exported_function_commonjs(self, js_support, temp_project): main_func = next(f for f in functions if f.function_name == "main") optimized_code = """\ -function main(data) { +export function main(data) { return data.map(helper); } """ @@ -1694,11 +1688,11 @@ def test_replace_exported_function_commonjs(self, js_support, temp_project): result = js_support.replace_function(original_source, main_func, optimized_code) expected_result = """\ -function helper(x) { +export function helper(x) { return x * 2; } -function main(data) { +export function main(data) { return data.map(helper); } @@ -1757,18 +1751,18 @@ def test_all_replacements_produce_valid_syntax(self, js_support, temp_project): test_cases = [ # (original, optimized, description) ( - "function f(x) { return x + 1; }", - "function f(x) { return ++x; }", + "export function f(x) { return x + 1; }", + "export function f(x) { return ++x; }", "increment replacement" ), ( - "function f(arr) { return arr.length > 0; }", - "function f(arr) { return !!arr.length; }", + "export function f(arr) { return arr.length > 0; }", + "export function f(arr) { return !!arr.length; }", "boolean conversion" ), ( - "function f(a, b) { if (a) { return a; } return b; }", - "function f(a, b) { return a || b; }", + "export function f(a, b) { if (a) { return a; } return b; }", + "export function f(a, b) { return a || b; }", "logical OR replacement" ), ] diff --git a/tests/test_languages/test_language_parity.py b/tests/test_languages/test_language_parity.py index ae57eb426..2b2035c84 100644 --- a/tests/test_languages/test_language_parity.py +++ b/tests/test_languages/test_language_parity.py @@ -38,7 +38,7 @@ def add(a, b): return a + b """, javascript=""" -function add(a, b) { +export function add(a, b) { return a + b; } """, @@ -58,15 +58,15 @@ def multiply(a, b): return a * b """, javascript=""" -function add(a, b) { +export function add(a, b) { return a + b; } -function subtract(a, b) { +export function subtract(a, b) { return a - b; } -function multiply(a, b) { +export function multiply(a, b) { return a * b; } """, @@ -83,11 +83,11 @@ def without_return(): print("hello") """, javascript=""" -function withReturn() { +export function withReturn() { return 1; } -function withoutReturn() { +export function withoutReturn() { console.log("hello"); } """, @@ -105,7 +105,7 @@ def multiply(self, a, b): return a * b """, javascript=""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -128,11 +128,11 @@ def sync_function(): return 1 """, javascript=""" -async function fetchData(url) { +export async function fetchData(url) { return await fetch(url); } -function syncFunction() { +export function syncFunction() { return 1; } """, @@ -148,7 +148,7 @@ def inner(): return inner() """, javascript=""" -function outer() { +export function outer() { function inner() { return 1; } @@ -167,7 +167,7 @@ def helper(x): return x * 2 """, javascript=""" -class Utils { +export class Utils { static helper(x) { return x * 2; } @@ -194,7 +194,7 @@ def standalone(): return 42 """, javascript=""" -class Calculator { +export class Calculator { add(a, b) { return a + b; } @@ -204,13 +204,13 @@ class Calculator { } } -class StringUtils { +export class StringUtils { reverse(s) { return s.split('').reverse().join(''); } } -function standalone() { +export function standalone() { return 42; } """, @@ -227,11 +227,11 @@ def sync_func(): return 2 """, javascript=""" -async function asyncFunc() { +export async function asyncFunc() { return 1; } -function syncFunc() { +export function syncFunc() { return 2; } """, @@ -249,11 +249,11 @@ def method(self): return 2 """, javascript=""" -function standalone() { +export function standalone() { return 1; } -class MyClass { +export class MyClass { method() { return 2; } @@ -906,7 +906,7 @@ def test_discover_and_replace_workflow(self, python_support, js_support): return n return fibonacci(n - 1) + fibonacci(n - 2) """ - js_original = """function fibonacci(n) { + js_original = """export function fibonacci(n) { if (n <= 1) { return n; } @@ -933,7 +933,7 @@ def test_discover_and_replace_workflow(self, python_support, js_support): memo[i] = memo[i-1] + memo[i-2] return memo[n] """ - js_optimized = """function fibonacci(n) { + js_optimized = """export function fibonacci(n) { // Memoized version const memo = {0: 0, 1: 1}; for (let i = 2; i <= n; i++) { @@ -994,13 +994,13 @@ def test_function_info_fields_populated(self, python_support, js_support): def test_arrow_functions_unique_to_js(self, js_support): """JavaScript arrow functions should be discovered (no Python equivalent).""" js_code = """ -const add = (a, b) => { +export const add = (a, b) => { return a + b; }; -const multiply = (x, y) => x * y; +export const multiply = (x, y) => x * y; -const identity = x => x; +export const identity = x => x; """ js_file = write_temp_file(js_code, ".js") funcs = js_support.discover_functions(js_file) @@ -1021,7 +1021,7 @@ def number_generator(): return 3 """ js_code = """ -function* numberGenerator() { +export function* numberGenerator() { yield 1; yield 2; return 3; @@ -1065,11 +1065,11 @@ def multi_decorated(): def test_function_expressions_js(self, js_support): """JavaScript function expressions should be discovered.""" js_code = """ -const add = function(a, b) { +export const add = function(a, b) { return a + b; }; -const namedExpr = function myFunc(x) { +export const namedExpr = function myFunc(x) { return x * 2; }; """ @@ -1132,7 +1132,7 @@ def greeting(): return "Hello, δΈ–η•Œ! 🌍" """ js_code = """ -function greeting() { +export function greeting() { return "Hello, δΈ–η•Œ! 🌍"; } """ diff --git a/tests/test_languages/test_typescript_code_extraction.py b/tests/test_languages/test_typescript_code_extraction.py index f97049943..b344a2492 100644 --- a/tests/test_languages/test_typescript_code_extraction.py +++ b/tests/test_languages/test_typescript_code_extraction.py @@ -119,7 +119,7 @@ def test_extract_simple_function(self, ts_support): """Test extracting code context for a simple function.""" with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: f.write(""" -function add(a: number, b: number): number { +export function add(a: number, b: number): number { return a + b; } """) @@ -147,7 +147,7 @@ def test_extract_async_function_with_template_literal(self, ts_support): const command_args = process.argv.slice(3); -async function execMongoEval(queryExpression, appsmithMongoURI) { +export async function execMongoEval(queryExpression, appsmithMongoURI) { queryExpression = queryExpression.trim(); if (command_args.includes("--pretty")) { @@ -186,7 +186,7 @@ def test_extract_function_with_complex_try_catch(self, ts_support): import fsPromises from "fs/promises"; import path from "path"; -async function figureOutContentsPath(root: string): Promise { +export async function figureOutContentsPath(root: string): Promise { const subfolders = await fsPromises.readdir(root, { withFileTypes: true }); try { @@ -238,7 +238,7 @@ def test_extracted_code_includes_imports(self, ts_support): import fs from "fs"; import path from "path"; -function readConfig(filename: string): string { +export function readConfig(filename: string): string { const fullPath = path.join(__dirname, filename); return fs.readFileSync(fullPath, "utf8"); } @@ -264,7 +264,7 @@ def test_extracted_code_includes_global_variables(self, ts_support): const CONFIG = { timeout: 5000 }; const MAX_RETRIES = 3; -async function fetchWithRetry(url: string): Promise { +export async function fetchWithRetry(url: string): Promise { for (let i = 0; i < MAX_RETRIES; i++) { try { const response = await fetch(url, { signal: AbortSignal.timeout(CONFIG.timeout) }); @@ -289,6 +289,164 @@ def test_extracted_code_includes_global_variables(self, ts_support): assert ts_support.validate_syntax(code_context.target_code) is True +class TestSameClassHelperExtraction: + """Tests for same-class helper method extraction. + + When a class method calls other methods from the same class, those helper + methods should be included inside the class wrapper (not appended outside), + because they may use class-specific syntax like 'private'. + """ + + def test_private_helper_method_inside_class_wrapper(self, ts_support): + """Test that private helper methods are included inside the class wrapper.""" + with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: + # Export the class and add return statements so discover_functions finds the methods + f.write(""" +export class EndpointGroup { + private endpoints: any[] = []; + + constructor() { + this.endpoints = []; + } + + post(path: string, handler: Function): EndpointGroup { + this.addEndpoint("POST", path, handler); + return this; + } + + private addEndpoint(method: string, path: string, handler: Function): void { + this.endpoints.push({ method, path, handler }); + return; + } +} +""") + f.flush() + file_path = Path(f.name) + + # Discover the 'post' method + functions = ts_support.discover_functions(file_path) + post_method = None + for func in functions: + if func.function_name == "post": + post_method = func + break + + assert post_method is not None, "post method should be discovered" + + # Extract code context + code_context = ts_support.extract_code_context( + post_method, file_path.parent, file_path.parent + ) + + # The extracted code should be syntactically valid + assert ts_support.validate_syntax(code_context.target_code) is True, ( + f"Extracted code should be valid TypeScript:\n{code_context.target_code}" + ) + + # Both post and addEndpoint should be inside the class + assert "class EndpointGroup" in code_context.target_code + assert "post(" in code_context.target_code + assert "private addEndpoint" in code_context.target_code + + # The private method should be inside the class, not outside + # Check that addEndpoint appears BEFORE the closing brace of the class + class_end_index = code_context.target_code.rfind("}") + add_endpoint_index = code_context.target_code.find("addEndpoint") + assert add_endpoint_index < class_end_index, ( + "addEndpoint should be inside the class wrapper" + ) + + def test_multiple_private_helpers_inside_class(self, ts_support): + """Test that multiple private helpers are all included inside the class.""" + with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: + f.write(""" +export class Router { + private routes: Map = new Map(); + + addRoute(path: string, handler: Function): boolean { + const normalizedPath = this.normalizePath(path); + this.validatePath(normalizedPath); + this.routes.set(normalizedPath, handler); + return true; + } + + private normalizePath(path: string): string { + return path.toLowerCase().trim(); + } + + private validatePath(path: string): boolean { + if (!path.startsWith("/")) { + throw new Error("Path must start with /"); + } + return true; + } +} +""") + f.flush() + file_path = Path(f.name) + + # Discover the 'addRoute' method + functions = ts_support.discover_functions(file_path) + add_route_method = None + for func in functions: + if func.function_name == "addRoute": + add_route_method = func + break + + assert add_route_method is not None + + code_context = ts_support.extract_code_context( + add_route_method, file_path.parent, file_path.parent + ) + + # Should be valid TypeScript + assert ts_support.validate_syntax(code_context.target_code) is True + + # All methods should be inside the class + assert "private normalizePath" in code_context.target_code + assert "private validatePath" in code_context.target_code + + def test_same_class_helpers_filtered_from_helper_list(self, ts_support): + """Test that same-class helpers are not duplicated in the helpers list.""" + with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: + f.write(""" +export class Calculator { + add(a: number, b: number): number { + return this.compute(a, b, "+"); + } + + private compute(a: number, b: number, op: string): number { + if (op === "+") return a + b; + return 0; + } +} +""") + f.flush() + file_path = Path(f.name) + + functions = ts_support.discover_functions(file_path) + add_method = None + for func in functions: + if func.function_name == "add": + add_method = func + break + + assert add_method is not None + + code_context = ts_support.extract_code_context( + add_method, file_path.parent, file_path.parent + ) + + # 'compute' should be in target_code (inside class) + assert "compute" in code_context.target_code + + # 'compute' should NOT be in helper_functions (would be duplicate) + helper_names = [h.name for h in code_context.helper_functions] + assert "compute" not in helper_names, ( + "Same-class helper 'compute' should not be in helper_functions list" + ) + + class TestTypeScriptLanguageProperties: """Tests for TypeScript language support properties.""" diff --git a/tests/test_languages/test_typescript_e2e.py b/tests/test_languages/test_typescript_e2e.py index 199094a1d..a638f01a1 100644 --- a/tests/test_languages/test_typescript_e2e.py +++ b/tests/test_languages/test_typescript_e2e.py @@ -285,7 +285,7 @@ def test_function_to_optimize_has_correct_fields(self): with tempfile.NamedTemporaryFile(suffix=".ts", mode="w", delete=False) as f: f.write(""" -class Calculator { +export class Calculator { add(a: number, b: number): number { return a + b; } @@ -295,7 +295,7 @@ class Calculator { } } -function standalone(x: number): number { +export function standalone(x: number): number { return x * 2; } """)