diff --git a/aws_advanced_python_wrapper/utils/concurrent.py b/aws_advanced_python_wrapper/utils/concurrent.py index 04836932..679933a0 100644 --- a/aws_advanced_python_wrapper/utils/concurrent.py +++ b/aws_advanced_python_wrapper/utils/concurrent.py @@ -14,13 +14,9 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, Iterator, Set, Union, ValuesView - -if TYPE_CHECKING: - from typing import ItemsView - from threading import Condition, Lock, RLock -from typing import Callable, Generic, KeysView, List, Optional, TypeVar +from typing import (Callable, Dict, Generic, Iterator, List, Optional, Set, + TypeVar, Union) K = TypeVar('K') V = TypeVar('V') @@ -111,14 +107,20 @@ def apply_if(self, predicate: Callable, apply: Callable): if predicate(key, value): apply(key, value) - def keys(self) -> KeysView: - return self._dict.keys() + def keys(self) -> List[K]: + """Returns a thread-safe snapshot of keys.""" + with self._lock: + return list(self._dict.keys()) - def values(self) -> ValuesView: - return self._dict.values() + def values(self) -> List[V]: + """Returns a thread-safe snapshot of values.""" + with self._lock: + return list(self._dict.values()) - def items(self) -> ItemsView: - return self._dict.items() + def items(self) -> List[tuple[K, V]]: + """Returns a thread-safe snapshot of items.""" + with self._lock: + return list(self._dict.items()) class ConcurrentSet(Generic[V]): diff --git a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py b/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py index 8dd9c219..4085e43c 100644 --- a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py +++ b/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py @@ -16,7 +16,7 @@ from threading import Thread from time import perf_counter_ns, sleep -from typing import Callable, Generic, ItemsView, KeysView, Optional, TypeVar +from typing import Callable, Generic, List, Optional, Tuple, TypeVar from aws_advanced_python_wrapper.utils.atomic import AtomicInt from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict @@ -46,10 +46,10 @@ def __len__(self): def set_cleanup_interval_ns(self, interval_ns): self._cleanup_interval_ns = interval_ns - def keys(self) -> KeysView: + def keys(self) -> List[K]: return self._cdict.keys() - def items(self) -> ItemsView: + def items(self) -> List[Tuple[K, CacheItem[V]]]: return self._cdict.items() def compute_if_absent(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]: @@ -73,32 +73,28 @@ def _remove_and_dispose(self, key: K): self._item_disposal_func(cache_item.item) def _remove_if_expired(self, key: K): - item = None - def _remove_if_expired_internal(_, cache_item): if self._should_cleanup_item(cache_item): - nonlocal item - item = cache_item.item + # Dispose while holding the lock to prevent race conditions + if self._item_disposal_func is not None: + self._item_disposal_func(cache_item.item) return None - return cache_item self._cdict.compute_if_present(key, _remove_if_expired_internal) - if item is None or self._item_disposal_func is None: - return - - self._item_disposal_func(item) - def _should_cleanup_item(self, cache_item: CacheItem) -> bool: if self._should_dispose_func is not None: return perf_counter_ns() > cache_item.expiration_time and self._should_dispose_func(cache_item.item) return perf_counter_ns() > cache_item.expiration_time def clear(self): - for _, cache_item in self._cdict.items(): - if cache_item is not None and self._item_disposal_func is not None: - self._item_disposal_func(cache_item.item) + # Dispose all items while holding the lock + if self._item_disposal_func is not None: + self._cdict.apply_if( + lambda k, v: True, # Apply to all items + lambda k, cache_item: self._item_disposal_func(cache_item.item) + ) self._cdict.clear() def _cleanup(self): @@ -107,7 +103,7 @@ def _cleanup(self): return self._cleanup_time_ns.set(current_time + self._cleanup_interval_ns) - keys = [key for key, _ in self._cdict.items()] + keys = self._cdict.keys() for key in keys: self._remove_if_expired(key) @@ -129,29 +125,21 @@ def compute_if_absent_with_disposal(self, key: K, mapping_func: Callable, item_e return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item def _remove_if_disposable(self, key: K): - item = None - def _remove_if_disposable_internal(_, cache_item): if self._should_dispose_func is not None and self._should_dispose_func(cache_item.item): - nonlocal item - item = cache_item.item + if self._item_disposal_func is not None: + self._item_disposal_func(cache_item.item) return None - return cache_item self._cdict.compute_if_present(key, _remove_if_disposable_internal) - if item is None or self._item_disposal_func is None: - return - - self._item_disposal_func(item) - def _cleanup_thread_internal(self): while True: try: sleep(self._cleanup_interval_ns / 1_000_000_000) self._cleanup_time_ns.set(perf_counter_ns() + self._cleanup_interval_ns) - keys = [key for key, _ in self._cdict.items()] + keys = self._cdict.keys() for key in keys: try: self._remove_if_expired(key) diff --git a/tests/unit/test_sql_alchemy_pooled_connection_provider.py b/tests/unit/test_sql_alchemy_pooled_connection_provider.py index a23ca903..ba36c08c 100644 --- a/tests/unit/test_sql_alchemy_pooled_connection_provider.py +++ b/tests/unit/test_sql_alchemy_pooled_connection_provider.py @@ -65,7 +65,7 @@ def clear_cache(): def test_connect__default_mapping__default_pool_configuration(provider, host_info, mocker, mock_conn, mock_pool): expected_urls = {host_info.url} - expected_keys = {PoolKey(host_info.url, "user1")} + expected_keys = [PoolKey(host_info.url, "user1")] props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"}) conn = provider.connect(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock(), host_info, props) @@ -76,7 +76,7 @@ def test_connect__default_mapping__default_pool_configuration(provider, host_inf def test_connect__custom_configuration_and_mapping(host_info, mocker, mock_conn, mock_pool): - expected_keys = {PoolKey(host_info.url, f"{host_info.url}+some_unique_key")} + expected_keys = [PoolKey(host_info.url, f"{host_info.url}+some_unique_key")] props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"}) attempt_creator_override_func = mocker.MagicMock()