diff --git a/mocket/decorators/mocketizer.py b/mocket/decorators/mocketizer.py index fb7c811..4174532 100644 --- a/mocket/decorators/mocketizer.py +++ b/mocket/decorators/mocketizer.py @@ -1,3 +1,8 @@ +from __future__ import annotations + +import functools +import inspect + from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import get_mocketize @@ -60,7 +65,7 @@ def check_and_call(self, method_name): def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): instance = args[0] if args else None namespace = None - if truesocket_recording_dir: + if truesocket_recording_dir and instance: namespace = ".".join( ( instance.__class__.__module__, @@ -78,7 +83,7 @@ def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, ar ) -def wrapper( +def _function_wrapper( test, truesocket_recording_dir=None, strict_mode=False, @@ -92,4 +97,54 @@ def wrapper( return test(*args, **kwargs) -mocketize = get_mocketize(wrapper) +_function_mocketize = get_mocketize(_function_wrapper) + + +def _class_decorator_factory(**options): + def decorator(cls): + orig_setup = getattr(cls, "setUp", lambda self, *a, **kw: None) + orig_td = getattr(cls, "tearDown", lambda self, *a, **kw: None) + use_add_cleanup = hasattr(cls, "addCleanup") + + def setUp(self, *a, **kw): + ctx = Mocketizer(instance=self, **options) + ctx.enter() + if use_add_cleanup: + self.addCleanup(ctx.exit) + else: + self.__mocket_ctx = ctx + orig_setup(self, *a, **kw) + + cls.setUp = functools.wraps(orig_setup)(setUp) + + if not use_add_cleanup: + + def tearDown(self, *a, **kw): + try: + orig_td(self, *a, **kw) + finally: + if hasattr(self, "__mocket_ctx"): + self.__mocket_ctx.exit() + + cls.tearDown = functools.wraps(orig_td)(tearDown) + + return cls + + return decorator + + +def mocketize(*dargs, **dkwargs): + # bare @mocketize + if dargs and len(dargs) == 1 and callable(dargs[0]) and not dkwargs: + target = dargs[0] + if inspect.isclass(target): + return _class_decorator_factory()(target) + return _function_mocketize(target) + + # @mocketize(...) + def real_decorator(target): + if inspect.isclass(target): + return _class_decorator_factory(**dkwargs)(target) + return _function_mocketize(**dkwargs)(target) + + return real_decorator