From 83113d1051d04ea5967052b226d7322a096a93da Mon Sep 17 00:00:00 2001 From: Sarthak Agarwal Date: Tue, 3 Feb 2026 03:56:07 +0530 Subject: [PATCH] alias support and vitest imports --- codeflash/cli_cmds/cli.py | 5 +- .../languages/javascript/import_resolver.py | 52 +++++++++++++- .../languages/javascript/module_system.py | 67 +++++++++++++++++++ codeflash/languages/javascript/support.py | 4 +- codeflash/verification/verification_utils.py | 7 +- codeflash/verification/verifier.py | 8 ++- 6 files changed, 136 insertions(+), 7 deletions(-) diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 11a9bf02e..b20f1fbcf 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -288,8 +288,9 @@ def process_pyproject_config(args: Namespace) -> Namespace: normalized_ignore_paths = [] for path in args.ignore_paths: path_obj = Path(path) - assert path_obj.exists(), f"ignore-paths config must be a valid path. Path {path} does not exist" - normalized_ignore_paths.append(path_obj.resolve()) + if path_obj.exists(): + normalized_ignore_paths.append(path_obj.resolve()) + # Silently skip non-existent paths (e.g., .next, dist before build) args.ignore_paths = normalized_ignore_paths # Project root path is one level above the specified directory, because that's where the module can be imported from args.module_root = Path(args.module_root).resolve() diff --git a/codeflash/languages/javascript/import_resolver.py b/codeflash/languages/javascript/import_resolver.py index 49452ec51..537132d96 100644 --- a/codeflash/languages/javascript/import_resolver.py +++ b/codeflash/languages/javascript/import_resolver.py @@ -124,6 +124,15 @@ def _resolve_module_path(self, module_path: str, source_dir: Path) -> Path | Non if module_path.startswith("/"): return self._resolve_absolute_import(module_path) + # Handle @/ path alias (common in Next.js/TypeScript projects) + # @/ maps to the project root + if module_path.startswith("@/"): + return self._resolve_path_alias(module_path[2:]) # Strip @/ + + # Handle ~/ path alias (another common pattern) + if module_path.startswith("~/"): + return self._resolve_path_alias(module_path[2:]) # Strip ~/ + # Bare imports (e.g., 'lodash') are external packages return None @@ -195,6 +204,38 @@ def _resolve_absolute_import(self, module_path: str) -> Path | None: return None + def _resolve_path_alias(self, module_path: str) -> Path | None: + """Resolve path alias imports like @/utils or ~/lib/helper. + + Args: + module_path: The import path without the alias prefix. + + Returns: + Resolved absolute path, or None if not found. + + """ + # Treat as relative to project root + base_path = (self.project_root / module_path).resolve() + + # Check if path is within project + try: + base_path.relative_to(self.project_root) + except ValueError: + logger.debug("Path alias resolves outside project root: %s", base_path) + return None + + # Try adding extensions + resolved = self._try_extensions(base_path) + if resolved: + return resolved + + # Try as directory with index file + resolved = self._try_index_file(base_path) + if resolved: + return resolved + + return None + def _try_extensions(self, base_path: Path) -> Path | None: """Try adding various extensions to find the actual file. @@ -265,10 +306,19 @@ def _is_external_package(self, module_path: str) -> bool: if module_path.startswith("/"): return False + # @/ is a common path alias in Next.js/TypeScript projects mapping to project root + # These are internal imports, not external npm packages + if module_path.startswith("@/"): + return False + + # ~/ is another common path alias pattern + if module_path.startswith("~/"): + return False + # Bare imports without ./ or ../ are external packages # This includes: # - 'lodash' - # - '@company/utils' + # - '@company/utils' (scoped npm packages) # - 'react' # - 'fs' (Node.js built-ins) return True diff --git a/codeflash/languages/javascript/module_system.py b/codeflash/languages/javascript/module_system.py index e88325d24..4e4e3bb0c 100644 --- a/codeflash/languages/javascript/module_system.py +++ b/codeflash/languages/javascript/module_system.py @@ -330,3 +330,70 @@ def ensure_module_system_compatibility(code: str, target_module_system: str) -> return convert_esm_to_commonjs(code) return code + + +def ensure_vitest_imports(code: str, test_framework: str) -> str: + """Ensure vitest test globals are imported when using vitest framework. + + Vitest by default does not enable globals (describe, test, expect, etc.), + so they must be explicitly imported. This function adds the import if missing. + + Args: + code: JavaScript/TypeScript test code. + test_framework: The test framework being used (vitest, jest, mocha). + + Returns: + Code with vitest imports added if needed. + + """ + if test_framework != "vitest": + return code + + # Check if vitest imports already exist + if "from 'vitest'" in code or 'from "vitest"' in code: + return code + + # Check if the code uses test functions that need to be imported + test_globals = ["describe", "test", "it", "expect", "vi", "beforeEach", "afterEach", "beforeAll", "afterAll"] + needs_import = any(f"{global_name}(" in code or f"{global_name} (" in code for global_name in test_globals) + + if not needs_import: + return code + + # Determine which globals are actually used in the code + used_globals = [g for g in test_globals if f"{g}(" in code or f"{g} (" in code] + if not used_globals: + return code + + # Build the import statement + import_statement = f"import {{ {', '.join(used_globals)} }} from 'vitest';\n" + + # Find the first line that isn't a comment or empty + lines = code.split("\n") + insert_index = 0 + for i, line in enumerate(lines): + stripped = line.strip() + if stripped and not stripped.startswith("//") and not stripped.startswith("/*") and not stripped.startswith("*"): + # Check if this line is an import/require - insert after imports + if stripped.startswith("import ") or stripped.startswith("const ") or stripped.startswith("let "): + continue + insert_index = i + break + insert_index = i + 1 + + # Find the last import line to insert after it + last_import_index = -1 + for i, line in enumerate(lines): + stripped = line.strip() + if stripped.startswith("import ") and "from " in stripped: + last_import_index = i + + if last_import_index >= 0: + # Insert after the last import + lines.insert(last_import_index + 1, import_statement.rstrip()) + else: + # Insert at the beginning (after any leading comments) + lines.insert(insert_index, import_statement.rstrip()) + + logger.debug("Added vitest imports: %s", used_globals) + return "\n".join(lines) diff --git a/codeflash/languages/javascript/support.py b/codeflash/languages/javascript/support.py index 0e77c2c5e..5a68761f6 100644 --- a/codeflash/languages/javascript/support.py +++ b/codeflash/languages/javascript/support.py @@ -62,7 +62,9 @@ def default_file_extension(self) -> str: @property def test_framework(self) -> str: """Primary test framework for JavaScript.""" - return "jest" + from codeflash.languages.test_framework import get_js_test_framework_or_default + + return get_js_test_framework_or_default() @property def comment_prefix(self) -> str: diff --git a/codeflash/verification/verification_utils.py b/codeflash/verification/verification_utils.py index 53dd6c80b..6bbe36fc6 100644 --- a/codeflash/verification/verification_utils.py +++ b/codeflash/verification/verification_utils.py @@ -86,10 +86,13 @@ class TestConfig: def test_framework(self) -> str: """Returns the appropriate test framework based on language. - Returns 'jest' for JavaScript/TypeScript, 'pytest' for Python (default). + For JavaScript/TypeScript: uses the configured framework (vitest, jest, or mocha). + For Python: uses pytest as default. """ if is_javascript(): - return "jest" + from codeflash.languages.test_framework import get_js_test_framework_or_default + + return get_js_test_framework_or_default() return "pytest" def set_language(self, language: str) -> None: diff --git a/codeflash/verification/verifier.py b/codeflash/verification/verifier.py index 19500b968..b3f66ee50 100644 --- a/codeflash/verification/verifier.py +++ b/codeflash/verification/verifier.py @@ -69,7 +69,10 @@ def generate_tests( instrument_generated_js_test, validate_and_fix_import_style, ) - from codeflash.languages.javascript.module_system import ensure_module_system_compatibility + from codeflash.languages.javascript.module_system import ( + ensure_module_system_compatibility, + ensure_vitest_imports, + ) source_file = Path(function_to_optimize.file_path) func_name = function_to_optimize.function_name @@ -81,6 +84,9 @@ def generate_tests( # Convert module system if needed (e.g., CommonJS -> ESM for ESM projects) generated_test_source = ensure_module_system_compatibility(generated_test_source, project_module_system) + # Ensure vitest imports are present when using vitest framework + generated_test_source = ensure_vitest_imports(generated_test_source, test_cfg.test_framework) + # Instrument for behavior verification (writes to SQLite) instrumented_behavior_test_source = instrument_generated_js_test( test_code=generated_test_source,