diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 62c53494..b9d96b1c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -10,6 +10,7 @@ Changes * ENH: Added CLI argument ``-m`` to ``kernprof`` for running a library module as a script; also made it possible for profiling targets to be supplied across multiple ``-p`` flags * FIX: Fixed explicit profiling of class methods; added handling for profiling static, bound, and partial methods, ``functools.partial`` objects, (cached) properties, and async generator functions * FIX: Fixed namespace bug when running ``kernprof -m`` on certain modules (e.g. ``calendar`` on Python 3.12+). +* FIX: Fixed ``@contextlib.contextmanager`` bug where the cleanup code (e.g. restoration of ``sys`` attributes) is not run if exceptions occurred inside the context 4.2.0 ~~~~~ diff --git a/kernprof.py b/kernprof.py index b8e558cf..52867331 100755 --- a/kernprof.py +++ b/kernprof.py @@ -80,7 +80,6 @@ def main(): --prof-imports If specified, modules specified to `--prof-mod` will also autoprofile modules that they import. Only works with line_profiler -l, --line-by-line """ import builtins -import contextlib import functools import os import sys @@ -224,8 +223,7 @@ def _python_command(): return sys.executable -@contextlib.contextmanager -def _restore_list(lst): +class _restore_list: """ Restore a list like `sys.path` after running code which potentially modifies it. @@ -248,9 +246,24 @@ def _restore_list(lst): >>> l [1, 2, 3] """ - old = lst.copy() - yield - lst[:] = old + def __init__(self, lst): + self.lst = lst + self.old = None + + def __enter__(self): + assert self.old is None + self.old = self.lst.copy() + + def __exit__(self, *_, **__): + self.old, self.lst[:] = None, self.old + + def __call__(self, func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with self: + return func(*args, **kwargs) + + return wrapper def pre_parse_single_arg_directive(args, flag, sep='--'): diff --git a/line_profiler/autoprofile/autoprofile.py b/line_profiler/autoprofile/autoprofile.py index d5894318..5985a84b 100644 --- a/line_profiler/autoprofile/autoprofile.py +++ b/line_profiler/autoprofile/autoprofile.py @@ -97,12 +97,21 @@ def run(script_file, ns, prof_mod, profile_imports=False, as_module=False): as_module (bool): Whether we're running script_file as a module """ - @contextlib.contextmanager - def restore_dict(d, target=None): - copy = d.copy() - yield target - d.clear() - d.update(copy) + class restore_dict: + def __init__(self, d, target=None): + self.d = d + self.target = target + self.copy = None + + def __enter__(self): + assert self.copy is None + self.copy = self.d.copy() + return self.target + + def __exit__(self, *_, **__): + self.d.clear() + self.d.update(self.copy) + self.copy = None if as_module: Profiler = AstTreeModuleProfiler diff --git a/tests/test_kernprof.py b/tests/test_kernprof.py index cfb7a553..c56dd215 100644 --- a/tests/test_kernprof.py +++ b/tests/test_kernprof.py @@ -1,3 +1,4 @@ +import contextlib import os import re import shlex @@ -8,7 +9,7 @@ import pytest import ubelt as ub -from kernprof import ContextualProfile +from kernprof import main, ContextualProfile def f(x): @@ -123,6 +124,59 @@ def main(): assert ('Function: main' in proc.stdout) == profiled_main +@pytest.mark.parametrize('error', [True, False]) +@pytest.mark.parametrize( + 'args', + ['', '-pmymod'], # Normal execution / auto-profile +) +def test_kernprof_sys_restoration(capsys, error, args): + """ + Test that `kernprof.main()` and + `line_profiler.autoprofile.autoprofile.run()` (resp.) properly + restores `sys.path` (resp. `sys.modules['__main__']`) on the way + out. + + Notes + ----- + The test is run in-process. + """ + with contextlib.ExitStack() as stack: + enter = stack.enter_context + tmpdir = enter(tempfile.TemporaryDirectory()) + assert tmpdir not in sys.path + temp_dpath = ub.Path(tmpdir) + (temp_dpath / 'mymod.py').write_text(ub.codeblock( + f''' + import sys + + + def main(): + # Mess up `sys.path` + sys.path.append({tmpdir!r}) + # Output + print(1) + # Optionally raise an error + if {error!r}: + raise Exception + + + if __name__ == '__main__': + main() + ''')) + enter(ub.ChDir(tmpdir)) + if error: + ctx = pytest.raises(BaseException) + else: + ctx = contextlib.nullcontext() + old_main = sys.modules.get('__main__') + with ctx: + main(['-l', *shlex.split(args), '-m', 'mymod']) + out, _ = capsys.readouterr() + assert out.startswith('1') + assert tmpdir not in sys.path + assert sys.modules.get('__main__') is old_main + + class TestKernprof(unittest.TestCase): def test_enable_disable(self): diff --git a/tests/test_line_profiler.py b/tests/test_line_profiler.py index b2a12028..85860b2a 100644 --- a/tests/test_line_profiler.py +++ b/tests/test_line_profiler.py @@ -41,18 +41,28 @@ def strip(s): return textwrap.dedent(s).strip('\n') -@contextlib.contextmanager -def check_timings(prof): +class check_timings: """ Verify that the profiler starts without timing data and ends with some. """ - timings = prof.get_stats().timings - assert not any(timings.values()), ('Expected no timing entries, ' - f'got {timings!r}') - yield prof - timings = prof.get_stats().timings - assert any(timings.values()), f'Expected timing entries, got {timings!r}' + def __init__(self, prof): + self.prof = prof + + def __enter__(self): + timings = self.timings + assert not any(timings.values()), ( + f'Expected no timing entries, got {timings!r}') + return self.prof + + def __exit__(self, *_, **__): + timings = self.timings + assert any(timings.values()), ( + f'Expected timing entries, got {timings!r}') + + @property + def timings(self): + return self.prof.get_stats().timings def test_init():