diff --git a/README.rst b/README.rst index fc6c48ca..374f6db2 100644 --- a/README.rst +++ b/README.rst @@ -94,7 +94,7 @@ You can add a default, pickle-based, persistent cache to your function - meaning """Your function now has a persistent cache mapped by argument values!""" return {'arg1': arg1, 'arg2': arg2} -Class and object methods can also be cached. Cachier will automatically ignore the `self` parameter when determining the cache key for an object method. **This means that methods will be cached across all instances of an object, which may not be what you want.** +Class and object methods can also be cached. Cachier will automatically ignore the ``self`` parameter when determining the cache key for an object method. **This means that methods will be cached across all instances of an object, which may not be what you want.** Because this is a common source of bugs, ``@cachier`` raises a ``TypeError`` by default when applied to an instance method (a function whose first parameter is named ``self``). This error is raised when ``@cachier`` is applied (at class definition time), not when the method is called. To opt in to cross-instance cache sharing, pass ``allow_non_static_methods=True``. .. code-block:: python @@ -107,17 +107,18 @@ Class and object methods can also be cached. 
Cachier will automatically ignore t return arg_1 + arg_2 # Instance method does not depend on object's internal state, so good to cache - @cachier() + @cachier(allow_non_static_methods=True) def good_usage_1(self, arg_1, arg_2): return arg_1 + arg_2 # Instance method is calling external service, probably okay to cache - @cachier() + @cachier(allow_non_static_methods=True) def good_usage_2(self, arg_1, arg_2): result = self.call_api(arg_1, arg_2) return result # Instance method relies on object attribute, NOT good to cache + # @cachier() would raise TypeError here -- this is intentional @cachier() def bad_usage(self, arg_1, arg_2): return arg_1 + arg_2 + self.arg_3 @@ -148,6 +149,7 @@ The following parameters will only be applied to decorators defined after `set_d * `pickle_reload` * `separate_files` * `entry_size_limit` +* `allow_non_static_methods` These parameters can be changed at any time and they will apply to all decorators: diff --git a/src/cachier/config.py b/src/cachier/config.py index f04d0cff..6b1652e4 100644 --- a/src/cachier/config.py +++ b/src/cachier/config.py @@ -66,6 +66,7 @@ class Params: cleanup_stale: bool = False cleanup_interval: timedelta = timedelta(days=1) entry_size_limit: Optional[int] = None + allow_non_static_methods: bool = False _global_params = Params() @@ -130,7 +131,13 @@ def set_global_params(**params: Any) -> None: 'cleanup_interval', and 'caching_enabled'. In some cores, if the decorator was created without concrete value for 'wait_for_calc_timeout', calls that check calculation timeouts will fall back to the global 'wait_for_calc_timeout' as well. + + Note that ``allow_non_static_methods`` (which controls whether + instance methods may be decorated) is a **decoration-time** + parameter: it is checked once when the ``@cachier`` decorator is + applied and is not re-read on each function call. 
Changing it via + ``set_global_params`` only affects decorators created after the call. """ import cachier diff --git a/src/cachier/core.py b/src/cachier/core.py index 9031e1fa..d4007b6b 100644 --- a/src/cachier/core.py +++ b/src/cachier/core.py @@ -200,6 +200,7 @@ def cachier( cleanup_stale: Optional[bool] = None, cleanup_interval: Optional[timedelta] = None, entry_size_limit: Optional[Union[int, str]] = None, + allow_non_static_methods: Optional[bool] = None, ): """Wrap as a persistent, stale-free memoization decorator. @@ -287,6 +288,13 @@ def cachier( Maximum serialized size of a cached value. Values exceeding the limit are returned but not cached. Human readable strings like ``"10MB"`` are allowed. + allow_non_static_methods : bool, optional + If True, allows ``@cachier`` to decorate instance methods (functions + whose first parameter is named ``self``). By default, decorating an + instance method raises ``TypeError`` because the ``self`` argument is + ignored for cache-key computation, meaning all instances share the + same cache -- which is rarely the intended behaviour. Set this to + ``True`` only when cross-instance cache sharing is intentional. """ # Check for deprecated parameters @@ -356,6 +364,23 @@ def cachier( def _cachier_decorator(func): core.set_func(func) + + # Guard: raise TypeError when decorating an instance method unless + # explicitly opted in. The 'self' parameter is ignored for cache-key + # computation, so all instances share the same cache. + if core.func_is_method: + _allow_methods = _update_with_defaults(allow_non_static_methods, "allow_non_static_methods") + if not _allow_methods: + raise TypeError( + f"@cachier cannot decorate instance method " + f"'{func.__qualname__}' because the 'self' parameter is " + "excluded from cache-key computation and all instances " + "would share a single cache. 
Pass allow_non_static_methods=True " + "to the decorator or call " + "set_global_params(allow_non_static_methods=True) if " + "cross-instance cache sharing is intentional." + ) + is_coroutine = inspect.iscoroutinefunction(func) if backend == "mongo": @@ -468,7 +493,6 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds): _print("max_age is negative. Cached result considered stale.") nonneg_max_age = False else: - assert max_age is not None # noqa: S101 max_allowed_age = min(_stale_after, max_age) # note: if max_age < 0, we always consider a value stale if nonneg_max_age and (now - entry.time <= max_allowed_age): @@ -557,7 +581,6 @@ async def _call_async(*args, max_age: Optional[timedelta] = None, **kwds): _print("max_age is negative. Cached result considered stale.") nonneg_max_age = False else: - assert max_age is not None # noqa: S101 max_allowed_age = min(_stale_after, max_age) # note: if max_age < 0, we always consider a value stale if nonneg_max_age and (now - entry.time <= max_allowed_age): diff --git a/src/cachier/cores/base.py b/src/cachier/cores/base.py index 547382da..1e099458 100644 --- a/src/cachier/cores/base.py +++ b/src/cachier/cores/base.py @@ -48,11 +48,22 @@ def __init__( self.wait_for_calc_timeout = wait_for_calc_timeout self.lock = threading.RLock() self.entry_size_limit = entry_size_limit + self.func_is_method: bool = False def set_func(self, func): """Set the function this core will use. - This has to be set before any method is called. Also determine if the function is an object method. + This must be called before any other method is invoked. In addition + to storing ``func`` on the instance, this method inspects the + function's signature and sets ``self.func_is_method`` to ``True`` + when the first parameter is named ``"self"``. + + Notes + ----- + Detection is name-based: only ``func_params[0] == "self"`` is + checked. 
``@classmethod`` functions whose first parameter is + conventionally named ``cls`` are not detected as methods -- + this is a known gap. """ # unwrap if the function is functools.partial diff --git a/tests/mongo_tests/test_async_mongo_core.py b/tests/mongo_tests/test_async_mongo_core.py index bcb9a233..d92226bd 100644 --- a/tests/mongo_tests/test_async_mongo_core.py +++ b/tests/mongo_tests/test_async_mongo_core.py @@ -76,7 +76,7 @@ async def async_mongetter(): call_count = 0 class _MongoMethods: - @cachier(mongetter=async_mongetter) + @cachier(mongetter=async_mongetter, allow_non_static_methods=True) async def async_cached_mongo_method_args_kwargs(self, x: int, y: int) -> int: nonlocal call_count call_count += 1 diff --git a/tests/redis_tests/test_async_redis_core.py b/tests/redis_tests/test_async_redis_core.py index 4d730d6b..e5dec8f7 100644 --- a/tests/redis_tests/test_async_redis_core.py +++ b/tests/redis_tests/test_async_redis_core.py @@ -84,7 +84,7 @@ async def get_redis_client(): call_count = 0 class _RedisMethods: - @cachier(backend="redis", redis_client=get_redis_client) + @cachier(backend="redis", redis_client=get_redis_client, allow_non_static_methods=True) async def async_cached_redis_method_args_kwargs(self, x: int, y: int) -> int: nonlocal call_count call_count += 1 diff --git a/tests/test_async_core.py b/tests/test_async_core.py index 5adeedb9..770c0ee4 100644 --- a/tests/test_async_core.py +++ b/tests/test_async_core.py @@ -261,7 +261,9 @@ class MyClass: def __init__(self, value): self.value = value - @cachier(backend="memory") + # allow_non_static_methods=True: cross-instance cache sharing + # is intentional in this test + @cachier(backend="memory", allow_non_static_methods=True) async def async_method(self, x): await asyncio.sleep(0.1) return x * self.value @@ -290,7 +292,9 @@ class MyClass: def __init__(self, value): self.value = value - @cachier(backend="memory") + # allow_non_static_methods=True: cross-instance cache sharing + # is intentional 
in this test + @cachier(backend="memory", allow_non_static_methods=True) async def async_method(self, x): await asyncio.sleep(0.1) return x * self.value @@ -311,6 +315,15 @@ async def async_method(self, x): obj1.async_method.clear_cache() + async def test_guard_raises_without_opt_in(self): + """Test that @cachier raises TypeError for async instance methods without opt-in.""" + with pytest.raises(TypeError, match="allow_non_static_methods"): + + class MyClass: + @cachier(backend="memory") + async def async_method(self, x): + return x + # ============================================================================= # Sync Function Compatibility Tests diff --git a/tests/test_caching_regression.py b/tests/test_caching_regression.py index 82f857cb..8b018885 100644 --- a/tests/test_caching_regression.py +++ b/tests/test_caching_regression.py @@ -105,7 +105,7 @@ def __init__(self, cache_ttl=None): cachier.enable_caching() # Use memory backend to avoid file cache persistence issues - @cachier.cachier(backend="memory") + @cachier.cachier(backend="memory", allow_non_static_methods=True) def test(self, param): self.counter += 1 return param diff --git a/tests/test_config.py b/tests/test_config.py index 19e7f528..ccd0f2a0 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,14 +2,17 @@ import pytest -from cachier.config import get_default_params, set_default_params +from cachier.config import get_default_params, get_global_params, set_default_params, set_global_params def test_set_default_params_deprecated(): """Test that set_default_params shows deprecation warning.""" - # Test lines 103-111: deprecation warning - with pytest.warns(DeprecationWarning, match="set_default_params.*deprecated.*set_global_params"): - set_default_params(stale_after=60) + original = get_global_params().stale_after + try: + with pytest.warns(DeprecationWarning, match="set_default_params.*deprecated.*set_global_params"): + set_default_params(stale_after=60) + finally: + 
set_global_params(stale_after=original) def test_get_default_params_deprecated(): diff --git a/tests/test_core_lookup.py b/tests/test_core_lookup.py index 10675b6c..599aa194 100644 --- a/tests/test_core_lookup.py +++ b/tests/test_core_lookup.py @@ -8,6 +8,7 @@ def test_get_default_params(): params = get_global_params() assert sorted(vars(params).keys()) == [ + "allow_non_static_methods", "allow_none", "backend", "cache_dir", diff --git a/tests/test_general.py b/tests/test_general.py index 280709e4..1592c9fb 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -185,7 +185,7 @@ def dummy_func(arg_1, arg_2): ) def test_ignore_self_in_methods(mongetter, backend): class DummyClass: - @cachier.cachier(backend=backend, mongetter=mongetter) + @cachier.cachier(backend=backend, mongetter=mongetter, allow_non_static_methods=True) def takes_2_seconds(self, arg_1, arg_2): """Some function.""" sleep(2) @@ -257,7 +257,7 @@ def test(): def test_global_disable_method(): class Test: - @cachier.cachier() + @cachier.cachier(allow_non_static_methods=True) def test(self): return True @@ -270,7 +270,7 @@ def test(self): def test_global_disable_method_with_args(): class Test: - @cachier.cachier() + @cachier.cachier(allow_non_static_methods=True) def test(self, test): return test @@ -286,7 +286,7 @@ class Test: def __init__(self, val): self.val = val - @cachier.cachier() + @cachier.cachier(allow_non_static_methods=True) def test(self, test=0): return self.val + test @@ -302,7 +302,7 @@ class Test: def __init__(self, val): self.val = val - @cachier.cachier() + @cachier.cachier(allow_non_static_methods=True) def test(self, test1, test2=0): return self.val + test1 + test2 diff --git a/tests/test_smoke.py b/tests/test_smoke.py index 31f2bd2f..1f38afd3 100644 --- a/tests/test_smoke.py +++ b/tests/test_smoke.py @@ -1,6 +1,8 @@ """Smoke tests for cachier - fast, no external service dependencies.""" import datetime +import hashlib +import pickle import pytest @@ -168,3 +170,70 @@ def 
func(): return 1 assert func.cache_dpath() is None + + +@pytest.mark.smoke +def test_classmethod_not_guarded(): + """Test that @classmethod (cls) does not trigger the instance method guard. + + Note: Place ``@classmethod`` above ``@cachier`` so that ``@cachier`` + decorates the underlying function before it is wrapped as a classmethod. + This way cachier sees ``cls`` (not ``self``) as the first parameter and + the instance-method guard is not triggered. + + """ + + # A custom hash_func is needed because the default pickle-based + # hash function cannot serialise the local class object passed as + # ``cls``. + def _hash_ignore_cls(args, kwds): + filtered = {k: v for k, v in kwds.items() if k != "cls"} + return hashlib.sha256(pickle.dumps((args, sorted(filtered.items())))).hexdigest() + + class Foo: + @classmethod + @cachier_decorator(backend="memory", hash_func=_hash_ignore_cls) + def method(cls, x): + return x + 1 + + Foo.method.clear_cache() + assert Foo.method(2) == 3 + Foo.method.clear_cache() + + +@pytest.mark.smoke +def test_instance_method_global_opt_out_reset(): + """Test that resetting allow_non_static_methods=False re-enables the guard.""" + original = get_global_params().allow_non_static_methods + try: + set_global_params(allow_non_static_methods=True) + set_global_params(allow_non_static_methods=False) + with pytest.raises(TypeError, match="instance method"): + + class Foo: + @cachier_decorator(backend="memory") + def method(self, x): + return x + finally: + set_global_params(allow_non_static_methods=original) + + +@pytest.mark.smoke +def test_instance_method_skip_cache(): + """Test that cachier__skip_cache=True works for methods with allow_non_static_methods.""" + call_count = 0 + + class Foo: + @cachier_decorator(backend="memory", allow_non_static_methods=True) + def method(self, x): + nonlocal call_count + call_count += 1 + return x * 2 + + obj = Foo() + obj.method.clear_cache() + assert obj.method(5) == 10 + assert call_count == 1 + assert 
obj.method(5, cachier__skip_cache=True) == 10 + assert call_count == 2 # recalculated, not from cache + obj.method.clear_cache()