Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ You can add a default, pickle-based, persistent cache to your function - meaning
"""Your function now has a persistent cache mapped by argument values!"""
return {'arg1': arg1, 'arg2': arg2}

Class and object methods can also be cached. Cachier will automatically ignore the `self` parameter when determining the cache key for an object method. **This means that methods will be cached across all instances of an object, which may not be what you want.**
Class and object methods can also be cached. Cachier will automatically ignore the ``self`` parameter when determining the cache key for an object method. **This means that methods will be cached across all instances of an object, which may not be what you want.** Because this is a common source of bugs, ``@cachier`` raises a ``TypeError`` by default when applied to an instance method (a function whose first parameter is named ``self``). This error is raised when ``@cachier`` is applied (at class definition time), not when the method is called. To opt in to cross-instance cache sharing, pass ``allow_non_static_methods=True``.

.. code-block:: python

Expand All @@ -107,17 +107,18 @@ Class and object methods can also be cached. Cachier will automatically ignore t
return arg_1 + arg_2

# Instance method does not depend on object's internal state, so good to cache
@cachier()
@cachier(allow_non_static_methods=True)
def good_usage_1(self, arg_1, arg_2):
return arg_1 + arg_2

# Instance method is calling external service, probably okay to cache
@cachier()
@cachier(allow_non_static_methods=True)
def good_usage_2(self, arg_1, arg_2):
result = self.call_api(arg_1, arg_2)
return result

# Instance method relies on object attribute, NOT good to cache
# @cachier() would raise TypeError here -- this is intentional
@cachier()
def bad_usage(self, arg_1, arg_2):
return arg_1 + arg_2 + self.arg_3
Expand Down Expand Up @@ -148,6 +149,7 @@ The following parameters will only be applied to decorators defined after `set_d
* `pickle_reload`
* `separate_files`
* `entry_size_limit`
* `allow_non_static_methods`

These parameters can be changed at any time and they will apply to all decorators:

Expand Down
10 changes: 9 additions & 1 deletion src/cachier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class Params:
cleanup_stale: bool = False
cleanup_interval: timedelta = timedelta(days=1)
entry_size_limit: Optional[int] = None
allow_non_static_methods: bool = False


_global_params = Params()
Expand Down Expand Up @@ -130,7 +131,14 @@ def set_global_params(**params: Any) -> None:
'cleanup_interval', and 'caching_enabled'. In some cores, if the
decorator was created without concrete value for 'wait_for_calc_timeout',
calls that check calculation timeouts will fall back to the global
'wait_for_calc_timeout' as well.
'wait_for_calc_timeout' as well. 'allow_non_static_methods'
(decoration-time only) controls whether instance methods are
permitted; it is read once when @cachier is applied, not on each call.

Note that ``allow_non_static_methods`` is a **decoration-time**
parameter: it is checked once when the ``@cachier`` decorator is
applied and is not re-read on each function call. Changing it via
``set_global_params`` only affects decorators created after the call.

"""
import cachier
Expand Down
27 changes: 25 additions & 2 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def cachier(
cleanup_stale: Optional[bool] = None,
cleanup_interval: Optional[timedelta] = None,
entry_size_limit: Optional[Union[int, str]] = None,
allow_non_static_methods: Optional[bool] = None,
):
"""Wrap as a persistent, stale-free memoization decorator.

Expand Down Expand Up @@ -287,6 +288,13 @@ def cachier(
Maximum serialized size of a cached value. Values exceeding the limit
are returned but not cached. Human readable strings like ``"10MB"`` are
allowed.
allow_non_static_methods : bool, optional
If True, allows ``@cachier`` to decorate instance methods (functions
whose first parameter is named ``self``). By default, decorating an
instance method raises ``TypeError`` because the ``self`` argument is
ignored for cache-key computation, meaning all instances share the
same cache -- which is rarely the intended behaviour. Set this to
``True`` only when cross-instance cache sharing is intentional.

"""
# Check for deprecated parameters
Expand Down Expand Up @@ -356,6 +364,23 @@ def cachier(

def _cachier_decorator(func):
core.set_func(func)

# Guard: raise TypeError when decorating an instance method unless
# explicitly opted in. The 'self' parameter is ignored for cache-key
# computation, so all instances share the same cache.
if core.func_is_method:
_allow_methods = _update_with_defaults(allow_non_static_methods, "allow_non_static_methods")
if not _allow_methods:
raise TypeError(
f"@cachier cannot decorate instance method "
f"'{func.__qualname__}' because the 'self' parameter is "
"excluded from cache-key computation and all instances "
"would share a single cache. Pass allow_non_static_methods=True "
"to the decorator or call "
"set_global_params(allow_non_static_methods=True) if "
"cross-instance cache sharing is intentional."
)

is_coroutine = inspect.iscoroutinefunction(func)

if backend == "mongo":
Expand Down Expand Up @@ -468,7 +493,6 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds):
_print("max_age is negative. Cached result considered stale.")
nonneg_max_age = False
else:
assert max_age is not None # noqa: S101
max_allowed_age = min(_stale_after, max_age)
# note: if max_age < 0, we always consider a value stale
if nonneg_max_age and (now - entry.time <= max_allowed_age):
Expand Down Expand Up @@ -557,7 +581,6 @@ async def _call_async(*args, max_age: Optional[timedelta] = None, **kwds):
_print("max_age is negative. Cached result considered stale.")
nonneg_max_age = False
else:
assert max_age is not None # noqa: S101
max_allowed_age = min(_stale_after, max_age)
# note: if max_age < 0, we always consider a value stale
if nonneg_max_age and (now - entry.time <= max_allowed_age):
Expand Down
13 changes: 12 additions & 1 deletion src/cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,22 @@ def __init__(
self.wait_for_calc_timeout = wait_for_calc_timeout
self.lock = threading.RLock()
self.entry_size_limit = entry_size_limit
self.func_is_method: bool = False

def set_func(self, func):
"""Set the function this core will use.

This has to be set before any method is called. Also determine if the function is an object method.
This must be called before any other method is invoked. In addition
to storing ``func`` on the instance, this method inspects the
function's signature and sets ``self.func_is_method`` to ``True``
when the first parameter is named ``"self"``.

Notes
-----
Detection is name-based: only ``func_params[0] == "self"`` is
checked. ``@classmethod`` functions whose first parameter is
conventionally named ``cls`` are not detected as methods --
this is a known gap.

"""
# unwrap if the function is functools.partial
Expand Down
2 changes: 1 addition & 1 deletion tests/mongo_tests/test_async_mongo_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def async_mongetter():
call_count = 0

class _MongoMethods:
@cachier(mongetter=async_mongetter)
@cachier(mongetter=async_mongetter, allow_non_static_methods=True)
async def async_cached_mongo_method_args_kwargs(self, x: int, y: int) -> int:
nonlocal call_count
call_count += 1
Expand Down
2 changes: 1 addition & 1 deletion tests/redis_tests/test_async_redis_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ async def get_redis_client():
call_count = 0

class _RedisMethods:
@cachier(backend="redis", redis_client=get_redis_client)
@cachier(backend="redis", redis_client=get_redis_client, allow_non_static_methods=True)
async def async_cached_redis_method_args_kwargs(self, x: int, y: int) -> int:
nonlocal call_count
call_count += 1
Expand Down
17 changes: 15 additions & 2 deletions tests/test_async_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,9 @@ class MyClass:
def __init__(self, value):
self.value = value

@cachier(backend="memory")
# allow_non_static_methods=True: cross-instance cache sharing
# is intentional in this test
@cachier(backend="memory", allow_non_static_methods=True)
async def async_method(self, x):
await asyncio.sleep(0.1)
return x * self.value
Expand Down Expand Up @@ -290,7 +292,9 @@ class MyClass:
def __init__(self, value):
self.value = value

@cachier(backend="memory")
# allow_non_static_methods=True: cross-instance cache sharing
# is intentional in this test
@cachier(backend="memory", allow_non_static_methods=True)
async def async_method(self, x):
await asyncio.sleep(0.1)
return x * self.value
Expand All @@ -311,6 +315,15 @@ async def async_method(self, x):

obj1.async_method.clear_cache()

async def test_guard_raises_without_opt_in(self):
"""Test that @cachier raises TypeError for async instance methods without opt-in."""
with pytest.raises(TypeError, match="allow_non_static_methods"):

class MyClass:
@cachier(backend="memory")
async def async_method(self, x):
return x


# =============================================================================
# Sync Function Compatibility Tests
Expand Down
2 changes: 1 addition & 1 deletion tests/test_caching_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(self, cache_ttl=None):
cachier.enable_caching()

# Use memory backend to avoid file cache persistence issues
@cachier.cachier(backend="memory")
@cachier.cachier(backend="memory", allow_non_static_methods=True)
def test(self, param):
self.counter += 1
return param
Expand Down
11 changes: 7 additions & 4 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

import pytest

from cachier.config import get_default_params, set_default_params
from cachier.config import get_default_params, get_global_params, set_default_params, set_global_params


def test_set_default_params_deprecated():
"""Test that set_default_params shows deprecation warning."""
# Test lines 103-111: deprecation warning
with pytest.warns(DeprecationWarning, match="set_default_params.*deprecated.*set_global_params"):
set_default_params(stale_after=60)
original = get_global_params().stale_after
try:
with pytest.warns(DeprecationWarning, match="set_default_params.*deprecated.*set_global_params"):
set_default_params(stale_after=60)
finally:
set_global_params(stale_after=original)


def test_get_default_params_deprecated():
Expand Down
1 change: 1 addition & 0 deletions tests/test_core_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
def test_get_default_params():
params = get_global_params()
assert sorted(vars(params).keys()) == [
"allow_non_static_methods",
"allow_none",
"backend",
"cache_dir",
Expand Down
10 changes: 5 additions & 5 deletions tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def dummy_func(arg_1, arg_2):
)
def test_ignore_self_in_methods(mongetter, backend):
class DummyClass:
@cachier.cachier(backend=backend, mongetter=mongetter)
@cachier.cachier(backend=backend, mongetter=mongetter, allow_non_static_methods=True)
def takes_2_seconds(self, arg_1, arg_2):
"""Some function."""
sleep(2)
Expand Down Expand Up @@ -257,7 +257,7 @@ def test():

def test_global_disable_method():
class Test:
@cachier.cachier()
@cachier.cachier(allow_non_static_methods=True)
def test(self):
return True

Expand All @@ -270,7 +270,7 @@ def test(self):

def test_global_disable_method_with_args():
class Test:
@cachier.cachier()
@cachier.cachier(allow_non_static_methods=True)
def test(self, test):
return test

Expand All @@ -286,7 +286,7 @@ class Test:
def __init__(self, val):
self.val = val

@cachier.cachier()
@cachier.cachier(allow_non_static_methods=True)
def test(self, test=0):
return self.val + test

Expand All @@ -302,7 +302,7 @@ class Test:
def __init__(self, val):
self.val = val

@cachier.cachier()
@cachier.cachier(allow_non_static_methods=True)
def test(self, test1, test2=0):
return self.val + test1 + test2

Expand Down
69 changes: 69 additions & 0 deletions tests/test_smoke.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Smoke tests for cachier - fast, no external service dependencies."""

import datetime
import hashlib
import pickle

import pytest

Expand Down Expand Up @@ -168,3 +170,70 @@ def func():
return 1

assert func.cache_dpath() is None


@pytest.mark.smoke
def test_classmethod_not_guarded():
"""Test that @classmethod (cls) does not trigger the instance method guard.

Note: Place ``@classmethod`` above ``@cachier`` so that ``@cachier``
decorates the underlying function before it is wrapped as a classmethod.
This way cachier sees ``cls`` (not ``self``) as the first parameter and
the instance-method guard is not triggered.

"""

# A custom hash_func is needed because the default pickle-based
# hash function cannot serialise the local class object passed as
# ``cls``.
def _hash_ignore_cls(args, kwds):
filtered = {k: v for k, v in kwds.items() if k != "cls"}
return hashlib.sha256(pickle.dumps((args, sorted(filtered.items())))).hexdigest()

class Foo:
@classmethod
@cachier_decorator(backend="memory", hash_func=_hash_ignore_cls)
def method(cls, x):
return x + 1

Foo.method.clear_cache()
assert Foo.method(2) == 3
Foo.method.clear_cache()


@pytest.mark.smoke
def test_instance_method_global_opt_out_reset():
"""Test that resetting allow_non_static_methods=False re-enables the guard."""
original = get_global_params().allow_non_static_methods
try:
set_global_params(allow_non_static_methods=True)
set_global_params(allow_non_static_methods=False)
with pytest.raises(TypeError, match="instance method"):

class Foo:
@cachier_decorator(backend="memory")
def method(self, x):
return x
finally:
set_global_params(allow_non_static_methods=original)


@pytest.mark.smoke
def test_instance_method_skip_cache():
"""Test that cachier__skip_cache=True works for methods with allow_non_static_methods."""
call_count = 0

class Foo:
@cachier_decorator(backend="memory", allow_non_static_methods=True)
def method(self, x):
nonlocal call_count
call_count += 1
return x * 2

obj = Foo()
obj.method.clear_cache()
assert obj.method(5) == 10
assert call_count == 1
assert obj.method(5, cachier__skip_cache=True) == 10
assert call_count == 2 # recalculated, not from cache
obj.method.clear_cache()
Loading