Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions aws_advanced_python_wrapper/utils/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,9 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Iterator, Set, Union, ValuesView

if TYPE_CHECKING:
from typing import ItemsView

from threading import Condition, Lock, RLock
from typing import Callable, Generic, KeysView, List, Optional, TypeVar
from typing import (Callable, Dict, Generic, Iterator, List, Optional, Set,
TypeVar, Union)

K = TypeVar('K')
V = TypeVar('V')
Expand Down Expand Up @@ -111,14 +107,20 @@ def apply_if(self, predicate: Callable, apply: Callable):
if predicate(key, value):
apply(key, value)

def keys(self) -> KeysView:
return self._dict.keys()
def keys(self) -> List[K]:
"""Returns a thread-safe snapshot of keys."""
with self._lock:
return list(self._dict.keys())

def values(self) -> ValuesView:
return self._dict.values()
def values(self) -> List[V]:
"""Returns a thread-safe snapshot of values."""
with self._lock:
return list(self._dict.values())

def items(self) -> ItemsView:
return self._dict.items()
def items(self) -> List[tuple[K, V]]:
"""Returns a thread-safe snapshot of items."""
with self._lock:
return list(self._dict.items())


class ConcurrentSet(Generic[V]):
Expand Down
44 changes: 16 additions & 28 deletions aws_advanced_python_wrapper/utils/sliding_expiration_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from threading import Thread
from time import perf_counter_ns, sleep
from typing import Callable, Generic, ItemsView, KeysView, Optional, TypeVar
from typing import Callable, Generic, List, Optional, Tuple, TypeVar

from aws_advanced_python_wrapper.utils.atomic import AtomicInt
from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict
Expand Down Expand Up @@ -46,10 +46,10 @@ def __len__(self):
def set_cleanup_interval_ns(self, interval_ns):
self._cleanup_interval_ns = interval_ns

def keys(self) -> KeysView:
def keys(self) -> List[K]:
return self._cdict.keys()

def items(self) -> ItemsView:
def items(self) -> List[Tuple[K, CacheItem[V]]]:
return self._cdict.items()

def compute_if_absent(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]:
Expand All @@ -73,32 +73,28 @@ def _remove_and_dispose(self, key: K):
self._item_disposal_func(cache_item.item)

def _remove_if_expired(self, key: K):
item = None

def _remove_if_expired_internal(_, cache_item):
if self._should_cleanup_item(cache_item):
nonlocal item
item = cache_item.item
# Dispose while holding the lock to prevent race conditions
if self._item_disposal_func is not None:
self._item_disposal_func(cache_item.item)
return None

return cache_item

self._cdict.compute_if_present(key, _remove_if_expired_internal)

if item is None or self._item_disposal_func is None:
return

self._item_disposal_func(item)

def _should_cleanup_item(self, cache_item: CacheItem) -> bool:
if self._should_dispose_func is not None:
return perf_counter_ns() > cache_item.expiration_time and self._should_dispose_func(cache_item.item)
return perf_counter_ns() > cache_item.expiration_time

def clear(self):
for _, cache_item in self._cdict.items():
if cache_item is not None and self._item_disposal_func is not None:
self._item_disposal_func(cache_item.item)
# Dispose all items while holding the lock
if self._item_disposal_func is not None:
self._cdict.apply_if(
lambda k, v: True, # Apply to all items
lambda k, cache_item: self._item_disposal_func(cache_item.item)
)
self._cdict.clear()

def _cleanup(self):
Expand All @@ -107,7 +103,7 @@ def _cleanup(self):
return

self._cleanup_time_ns.set(current_time + self._cleanup_interval_ns)
keys = [key for key, _ in self._cdict.items()]
keys = self._cdict.keys()
for key in keys:
self._remove_if_expired(key)

Expand All @@ -129,29 +125,21 @@ def compute_if_absent_with_disposal(self, key: K, mapping_func: Callable, item_e
return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item

def _remove_if_disposable(self, key: K):
item = None

def _remove_if_disposable_internal(_, cache_item):
if self._should_dispose_func is not None and self._should_dispose_func(cache_item.item):
nonlocal item
item = cache_item.item
if self._item_disposal_func is not None:
self._item_disposal_func(cache_item.item)
return None

return cache_item

self._cdict.compute_if_present(key, _remove_if_disposable_internal)

if item is None or self._item_disposal_func is None:
return

self._item_disposal_func(item)

def _cleanup_thread_internal(self):
while True:
try:
sleep(self._cleanup_interval_ns / 1_000_000_000)
self._cleanup_time_ns.set(perf_counter_ns() + self._cleanup_interval_ns)
keys = [key for key, _ in self._cdict.items()]
keys = self._cdict.keys()
for key in keys:
try:
self._remove_if_expired(key)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_sql_alchemy_pooled_connection_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def clear_cache():

def test_connect__default_mapping__default_pool_configuration(provider, host_info, mocker, mock_conn, mock_pool):
expected_urls = {host_info.url}
expected_keys = {PoolKey(host_info.url, "user1")}
expected_keys = [PoolKey(host_info.url, "user1")]
props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"})

conn = provider.connect(mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock(), host_info, props)
Expand All @@ -76,7 +76,7 @@ def test_connect__default_mapping__default_pool_configuration(provider, host_inf


def test_connect__custom_configuration_and_mapping(host_info, mocker, mock_conn, mock_pool):
expected_keys = {PoolKey(host_info.url, f"{host_info.url}+some_unique_key")}
expected_keys = [PoolKey(host_info.url, f"{host_info.url}+some_unique_key")]
props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"})
attempt_creator_override_func = mocker.MagicMock()

Expand Down
Loading