From 218beb66a29340acf26b4f032a86147a5a489073 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 24 Feb 2026 22:42:01 +0530 Subject: [PATCH 1/2] perf: fix three performance bottlenecks in optimizer, comparator, and API client Cache normalized imported modules across prepare_module_for_optimization calls to eliminate redundant file I/O and AST parsing for shared imports. Move conditional library imports in comparator.py from inside the recursive comparator() function to module level to avoid per-call import machinery overhead. Use a module-level requests.Session in cfapi.py for HTTP connection pooling instead of creating new TCP/TLS connections per request. Co-Authored-By: Claude Opus 4.6 --- codeflash/api/cfapi.py | 6 +- codeflash/optimization/optimizer.py | 10 ++- codeflash/verification/comparator.py | 96 ++++++++++++++++++++-------- 3 files changed, 81 insertions(+), 31 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index f7957fa0d..900d0f1e8 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -27,6 +27,8 @@ from packaging import version +cfapi_session = requests.Session() + @dataclass class BaseUrls: @@ -77,9 +79,9 @@ def make_cfapi_request( if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) cfapi_headers["Content-Type"] = "application/json" - response = requests.post(url, data=json_payload, headers=cfapi_headers, timeout=60) + response = cfapi_session.post(url, data=json_payload, headers=cfapi_headers, timeout=60) else: - response = requests.get(url, headers=cfapi_headers, params=params, timeout=60) + response = cfapi_session.get(url, headers=cfapi_headers, params=params, timeout=60) response.raise_for_status() return response except requests.exceptions.HTTPError: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 3211ab59b..639a074b1 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -70,6 +70,7 @@ def __init__(self, args: Namespace) -> None: self.current_worktree: Path | None = None self.original_args_and_test_cfg: tuple[Namespace, TestConfig] | None = None self.patch_files: list[Path] = [] + self.normalized_imports_cache: dict[Path, ValidCode] = {} @staticmethod def _find_js_project_root(file_path: Path) -> Path | None: @@ -328,6 +329,9 @@ def prepare_module_for_optimization( has_syntax_error = False for analysis in imported_module_analyses: + if analysis.file_path in self.normalized_imports_cache: + validated_original_code[analysis.file_path] = self.normalized_imports_cache[analysis.file_path] + continue callee_original_code = analysis.file_path.read_text(encoding="utf8") try: normalized_callee_original_code = normalize_code(callee_original_code) @@ -336,9 +340,9 @@ def prepare_module_for_optimization( logger.info("Skipping optimization due to helper file error.") has_syntax_error = True break - validated_original_code[analysis.file_path] = ValidCode( - source_code=callee_original_code, normalized_code=normalized_callee_original_code - ) + valid_code = ValidCode(source_code=callee_original_code, normalized_code=normalized_callee_original_code) + self.normalized_imports_cache[analysis.file_path] = valid_code + validated_original_code[analysis.file_path] = valid_code if has_syntax_error: return None diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 6429b5520..77a286812 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -28,6 +28,76 @@ HAS_NUMBA = find_spec("numba") is not None HAS_PYARROW = find_spec("pyarrow") is not None +if HAS_JAX: + try: + import jax # type: ignore # noqa: PGH003 + import jax.numpy as jnp # type: ignore # noqa: PGH003 + except ImportError: + HAS_JAX = False + +if HAS_XARRAY: + try: + import xarray # type: ignore # noqa: PGH003 + except ImportError: + HAS_XARRAY = False + +if HAS_TENSORFLOW: + try: + import tensorflow as tf # type: ignore # noqa: PGH003 + except ImportError: + HAS_TENSORFLOW = False + +if HAS_SQLALCHEMY: + try: + import sqlalchemy # type: ignore # noqa: PGH003 + except ImportError: + HAS_SQLALCHEMY = False + +if HAS_SCIPY: + try: + import scipy # type: ignore # noqa: PGH003 + except ImportError: + HAS_SCIPY = False + +if HAS_NUMPY: + try: + import numpy as np # type: ignore # noqa: PGH003 + except ImportError: + HAS_NUMPY = False + +if HAS_PYARROW: + try: + import pyarrow as pa # type: ignore # noqa: PGH003 + except ImportError: + HAS_PYARROW = False + +if HAS_PANDAS: + try: + import pandas # type: ignore # noqa: ICN001, PGH003 + except ImportError: + HAS_PANDAS = False + +if HAS_TORCH: + try: + import torch # type: ignore # noqa: PGH003 + except ImportError: + HAS_TORCH = False + +if HAS_NUMBA: + try: + import numba # type: ignore # noqa: PGH003 + from numba.core.dispatcher import Dispatcher # type: ignore # noqa: PGH003 + from numba.typed import Dict as NumbaDict # type: ignore # noqa: PGH003 + from numba.typed import List as NumbaList # type: ignore # noqa: PGH003 + except ImportError: + HAS_NUMBA = False + +if HAS_PYRSISTENT: + try: + import pyrsistent # type: ignore # noqa: PGH003 + except ImportError: + HAS_PYRSISTENT = False + # Pattern to match pytest temp directories: /tmp/pytest-of-/pytest-/ # These paths vary between test runs but are logically equivalent PYTEST_TEMP_PATH_PATTERN = re.compile(r"/tmp/pytest-of-[^/]+/pytest-\d+/") # noqa: S108 @@ -185,9 +255,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return comparator(orig_referent, new_referent, superset_obj) if HAS_JAX: - import jax # type: ignore # noqa: PGH003 - import jax.numpy as jnp # type: ignore # noqa: PGH003 - # Handle JAX arrays first to avoid boolean context errors in other conditions if isinstance(orig, jax.Array): if orig.dtype != new.dtype: @@ -198,15 +265,11 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: # Handle xarray objects before numpy to avoid boolean context errors if HAS_XARRAY: - import xarray # type: ignore # noqa: PGH003 - if isinstance(orig, (xarray.Dataset, xarray.DataArray)): return orig.identical(new) # Handle TensorFlow objects early to avoid boolean context errors if HAS_TENSORFLOW: - import tensorflow as tf # type: ignore # noqa: PGH003 - if isinstance(orig, tf.Tensor): if orig.dtype != new.dtype: return False @@ -243,8 +306,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return comparator(orig.to_list(), new.to_list(), superset_obj) if HAS_SQLALCHEMY: - import sqlalchemy # type: ignore # noqa: PGH003 - try: insp = sqlalchemy.inspection.inspect(orig) insp = sqlalchemy.inspection.inspect(new) @@ -260,8 +321,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: except sqlalchemy.exc.NoInspectionAvailable: pass - if HAS_SCIPY: - import scipy # type: ignore # noqa: PGH003 # scipy condition because dok_matrix type is also a instance of dict, but dict comparison doesn't work for it if isinstance(orig, dict) and not (HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix)): if superset_obj: @@ -293,8 +352,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return comparator(dict(orig), dict(new), superset_obj) if HAS_NUMPY: - import numpy as np - if isinstance(orig, (np.datetime64, np.timedelta64)): # Handle NaT (Not a Time) - numpy's equivalent of NaN for datetime if np.isnat(orig) and np.isnat(new): @@ -356,8 +413,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return (orig != new).nnz == 0 if HAS_PYARROW: - import pyarrow as pa # type: ignore # noqa: PGH003 - if isinstance(orig, pa.Table): if orig.schema != new.schema: return False @@ -400,8 +455,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return bool(orig.equals(new)) if HAS_PANDAS: - import pandas # noqa: ICN001 - if isinstance( orig, (pandas.DataFrame, pandas.Series, pandas.Index, pandas.Categorical, pandas.arrays.SparseArray) ): @@ -432,8 +485,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: pass if HAS_TORCH: - import torch # type: ignore # noqa: PGH003 - if isinstance(orig, torch.Tensor): if orig.dtype != new.dtype: return False @@ -452,11 +503,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return orig == new if HAS_NUMBA: - import numba - from numba.core.dispatcher import Dispatcher - from numba.typed import Dict as NumbaDict - from numba.typed import List as NumbaList - # Handle numba typed List if isinstance(orig, NumbaList): if len(orig) != len(new): @@ -488,8 +534,6 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return orig.py_func is new.py_func if HAS_PYRSISTENT: - import pyrsistent # type: ignore # noqa: PGH003 - if isinstance( orig, ( From a922a29a99cf1a9f174e681e669b6ece40ca6728 Mon Sep 17 00:00:00 2001 From: aseembits93 Date: Tue, 3 Mar 2026 05:15:26 +0530 Subject: [PATCH 2/2] persistent session not needed --- codeflash/api/cfapi.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/codeflash/api/cfapi.py b/codeflash/api/cfapi.py index 900d0f1e8..f7957fa0d 100644 --- a/codeflash/api/cfapi.py +++ b/codeflash/api/cfapi.py @@ -27,8 +27,6 @@ from packaging import version -cfapi_session = requests.Session() - @dataclass class BaseUrls: @@ -79,9 +77,9 @@ def make_cfapi_request( if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) cfapi_headers["Content-Type"] = "application/json" - response = cfapi_session.post(url, data=json_payload, headers=cfapi_headers, timeout=60) + response = requests.post(url, data=json_payload, headers=cfapi_headers, timeout=60) else: - response = cfapi_session.get(url, headers=cfapi_headers, params=params, timeout=60) + response = requests.get(url, headers=cfapi_headers, params=params, timeout=60) response.raise_for_status() return response except requests.exceptions.HTTPError: