Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aws_advanced_python_wrapper/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
MonitoringThreadContainer
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
SlidingExpirationCacheContainer


def release_resources() -> None:
    """Release all global resources used by the wrapper.

    Tears down, in this order: the monitoring thread container, the shared
    thread pools, the opened-connection tracker, and the sliding-expiration
    cache container.
    """
    # Run each global teardown hook in sequence; order mirrors the original
    # call order and is preserved deliberately.
    teardown_hooks = (
        MonitoringThreadContainer.clean_up,
        ThreadPoolContainer.release_resources,
        OpenedConnectionTracker.release_resources,
        SlidingExpirationCacheContainer.release_resources,
    )
    for hook in teardown_hooks:
        hook()
7 changes: 4 additions & 3 deletions aws_advanced_python_wrapper/connection_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Dict, Optional, Protocol, Tuple
from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, Optional,
Protocol, Tuple)

if TYPE_CHECKING:
from aws_advanced_python_wrapper.database_dialect import DatabaseDialect
Expand Down Expand Up @@ -131,8 +132,8 @@ def connect(


class ConnectionProviderManager:
_lock: Lock = Lock()
_conn_provider: Optional[ConnectionProvider] = None
_lock: ClassVar[Lock] = Lock()
_conn_provider: ClassVar[Optional[ConnectionProvider]] = None

def __init__(self, default_provider: ConnectionProvider = DriverConnectionProvider()):
self._default_provider: ConnectionProvider = default_provider
Expand Down
20 changes: 12 additions & 8 deletions aws_advanced_python_wrapper/custom_endpoint_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
SlidingExpirationCacheWithCleanupThread
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
SlidingExpirationCacheContainer
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
TelemetryCounter, TelemetryFactory)

Expand Down Expand Up @@ -232,11 +232,8 @@ class CustomEndpointPlugin(Plugin):
or removing an instance in the custom endpoint.
"""
_SUBSCRIBED_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECT.method_name}
_CACHE_CLEANUP_RATE_NS: ClassVar[int] = 6 * 10 ^ 10 # 1 minute
_monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, CustomEndpointMonitor]] = \
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_RATE_NS,
should_dispose_func=lambda _: True,
item_disposal_func=lambda monitor: monitor.close())
_CACHE_CLEANUP_RATE_NS: ClassVar[int] = 60_000_000_000 # 1 minute
_MONITOR_CACHE_NAME: ClassVar[str] = "custom_endpoint_monitors"

def __init__(self, plugin_service: PluginService, props: Properties):
self._plugin_service = plugin_service
Expand All @@ -255,6 +252,13 @@ def __init__(self, plugin_service: PluginService, props: Properties):
telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
self._wait_for_info_counter: TelemetryCounter | None = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter")

self._monitors = SlidingExpirationCacheContainer.get_or_create_cache(
name=CustomEndpointPlugin._MONITOR_CACHE_NAME,
cleanup_interval_ns=CustomEndpointPlugin._CACHE_CLEANUP_RATE_NS,
should_dispose_func=lambda _: True,
item_disposal_func=lambda monitor: monitor.close()
)

CustomEndpointPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods)

@property
Expand Down Expand Up @@ -298,7 +302,7 @@ def _create_monitor_if_absent(self, props: Properties) -> CustomEndpointMonitor:
host_info = cast('HostInfo', self._custom_endpoint_host_info)
endpoint_id = cast('str', self._custom_endpoint_id)
region = cast('str', self._region)
monitor = CustomEndpointPlugin._monitors.compute_if_absent(
monitor = self._monitors.compute_if_absent(
host_info.host,
lambda key: CustomEndpointMonitor(
self._plugin_service,
Expand Down
3 changes: 2 additions & 1 deletion aws_advanced_python_wrapper/database_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,6 +695,7 @@ def __init__(self, props: Properties, rds_helper: Optional[RdsUtils] = None):
self._can_update: bool = False
self._dialect: DatabaseDialect = UnknownDatabaseDialect()
self._dialect_code: DialectCode = DialectCode.UNKNOWN
self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)

@staticmethod
def get_custom_dialect():
Expand Down Expand Up @@ -814,7 +815,7 @@ def query_for_dialect(self, url: str, host_info: Optional[HostInfo], conn: Conne
timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get(self._props)
try:
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
ThreadPoolContainer.get_thread_pool(DatabaseDialectManager._executor_name),
self._thread_pool,
timeout_sec,
driver_dialect,
conn)(dialect_candidate.is_dialect)
Expand Down
3 changes: 2 additions & 1 deletion aws_advanced_python_wrapper/driver_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class DriverDialect(ABC):

def __init__(self, props: Properties):
    """Store the connection properties and resolve the shared thread pool.

    :param props: connection/wrapper properties for this dialect instance.
    """
    self._props = props
    # Resolve the shared executor once at construction time instead of
    # looking it up on every timed execute() call.
    self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)

@property
def driver_name(self):
Expand Down Expand Up @@ -138,7 +139,7 @@ def execute(

if exec_timeout > 0:
try:
execute_with_timeout = timeout(ThreadPoolContainer.get_thread_pool(DriverDialect._executor_name), exec_timeout)(exec_func)
execute_with_timeout = timeout(self._thread_pool, exec_timeout)(exec_func)
return execute_with_timeout()
except TimeoutError as e:
raise QueryTimeoutError(Messages.get_formatted("DriverDialect.ExecuteTimeout", method_name)) from e
Expand Down
34 changes: 21 additions & 13 deletions aws_advanced_python_wrapper/fastest_response_strategy_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
SlidingExpirationCacheWithCleanupThread
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
SlidingExpirationCacheContainer
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
TelemetryContext, TelemetryFactory, TelemetryGauge, TelemetryTraceLevel)

Expand Down Expand Up @@ -59,7 +59,7 @@ def __init__(self, plugin_service: PluginService, props: Properties):
self._properties = props
self._host_response_time_service: HostResponseTimeService = \
HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props))
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 10 ^ 6
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 1_000_000
self._random_host_selector = RandomHostSelector()
self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap()
self._hosts: Tuple[HostInfo, ...] = ()
Expand Down Expand Up @@ -278,21 +278,29 @@ def _open_connection(self):


class HostResponseTimeService:
_CACHE_EXPIRATION_NS: int = 6 * 10 ^ 11 # 10 minutes
_CACHE_CLEANUP_NS: int = 6 * 10 ^ 10 # 1 minute
_lock: Lock = Lock()
_monitoring_hosts: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostResponseTimeMonitor]] = \
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS,
should_dispose_func=lambda monitor: True,
item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor))
_CACHE_EXPIRATION_NS: ClassVar[int] = 10 * 60_000_000_000 # 10 minutes
_CACHE_CLEANUP_NS: ClassVar[int] = 60_000_000_000 # 1 minute
_CACHE_NAME: ClassVar[str] = "host_response_time_monitors"
_lock: ClassVar[Lock] = Lock()

def __init__(self, plugin_service: PluginService, props: Properties, interval_ms: int):
self._plugin_service = plugin_service
self._properties = props
self._interval_ms = interval_ms
self._hosts: Tuple[HostInfo, ...] = ()
self._telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory()
self._host_count_gauge: TelemetryGauge | None = self._telemetry_factory.create_gauge("frt.hosts.count", lambda: len(self._monitoring_hosts))

self._monitoring_hosts = SlidingExpirationCacheContainer.get_or_create_cache(
name=HostResponseTimeService._CACHE_NAME,
cleanup_interval_ns=HostResponseTimeService._CACHE_CLEANUP_NS,
should_dispose_func=lambda monitor: True,
item_disposal_func=lambda monitor: HostResponseTimeService._monitor_close(monitor)
)

self._host_count_gauge: TelemetryGauge | None = self._telemetry_factory.create_gauge(
"frt.hosts.count",
lambda: len(self._monitoring_hosts)
)

@property
def hosts(self) -> Tuple[HostInfo, ...]:
Expand All @@ -310,7 +318,7 @@ def _monitor_close(monitor: HostResponseTimeMonitor):
pass

def get_response_time(self, host_info: HostInfo) -> int:
    """Return the most recently measured response time for the given host.

    :param host_info: the host whose monitored response time is requested.
    :return: the monitor's latest response time, or ``MAX_VALUE`` when no
        monitor has been registered for the host's URL yet.
    """
    # NOTE(review): the pasted diff contained both the old class-attribute
    # lookup and the new instance-attribute lookup back-to-back (a dead
    # duplicate assignment); keep only the instance-level cache lookup.
    monitor: Optional[HostResponseTimeMonitor] = self._monitoring_hosts.get(host_info.url)
    if monitor is None:
        return MAX_VALUE
    return monitor.response_time
Expand All @@ -327,4 +335,4 @@ def set_hosts(self, new_hosts: Tuple[HostInfo, ...]) -> None:
self._plugin_service,
host,
self._properties,
self._interval_ms), self._CACHE_EXPIRATION_NS)
self._interval_ms), HostResponseTimeService._CACHE_EXPIRATION_NS)
34 changes: 17 additions & 17 deletions aws_advanced_python_wrapper/host_list_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
ClusterTopologyMonitor, ClusterTopologyMonitorImpl)
from aws_advanced_python_wrapper.utils.decorators import \
preserve_transaction_status_with_timeout
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
SlidingExpirationCacheWithCleanupThread
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
SlidingExpirationCacheContainer

if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
Expand Down Expand Up @@ -476,6 +476,7 @@ def __init__(self, dialect: db_dialect.TopologyAwareDatabaseDialect, props: Prop

self.instance_template: HostInfo = instance_template
self._max_timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get_int(props)
self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)

def _validate_host_pattern(self, host: str):
if not self._rds_utils.is_dns_pattern_valid(host):
Expand Down Expand Up @@ -507,7 +508,7 @@ def query_for_topology(
an empty tuple will be returned.
"""
query_for_topology_func_with_timeout = preserve_transaction_status_with_timeout(
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, conn)(self._query_for_topology)
self._thread_pool, self._max_timeout_sec, driver_dialect, conn)(self._query_for_topology)
x = query_for_topology_func_with_timeout(conn)
return x

Expand Down Expand Up @@ -570,7 +571,7 @@ def create_host(
def get_host_role(self, connection: Connection, driver_dialect: DriverDialect) -> HostRole:
try:
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_host_role)
self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_host_role)
result = cursor_execute_func_with_timeout(connection)
if result is not None:
is_reader = result[0]
Expand All @@ -593,7 +594,7 @@ def get_host_id(self, connection: Connection, driver_dialect: DriverDialect) ->
"""

cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_host_id)
self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_host_id)
result = cursor_execute_func_with_timeout(connection)
if result:
host_id: str = result[0]
Expand All @@ -608,7 +609,7 @@ def _get_host_id(self, conn: Connection):
def get_writer_host_if_connected(self, connection: Connection, driver_dialect: DriverDialect) -> Optional[str]:
try:
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_writer_id)
self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_writer_id)
result = cursor_execute_func_with_timeout(connection)
if result:
host_id: str = result[0]
Expand Down Expand Up @@ -752,13 +753,9 @@ def _create_multi_az_host(self, record: Tuple, writer_id: str) -> HostInfo:


class MonitoringRdsHostListProvider(RdsHostListProvider):
_CACHE_CLEANUP_NANO = 1 * 60 * 1_000_000_000 # 1 minute
_MONITOR_CLEANUP_NANO = 15 * 60 * 1_000_000_000 # 15 minutes

_monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, ClusterTopologyMonitor]] = \
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO,
should_dispose_func=lambda monitor: monitor.can_dispose(),
item_disposal_func=lambda monitor: monitor.close())
_CACHE_CLEANUP_NANO: ClassVar[int] = 1 * 60 * 1_000_000_000 # 1 minute
_MONITOR_CLEANUP_NANO: ClassVar[int] = 15 * 60 * 1_000_000_000 # 15 minutes
_MONITOR_CACHE_NAME: ClassVar[str] = "cluster_topology_monitors"

def __init__(
self,
Expand All @@ -772,6 +769,13 @@ def __init__(
self._high_refresh_rate_ns = (
WrapperProperties.CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.get_int(self._props) * 1_000_000)

self._monitors = SlidingExpirationCacheContainer.get_or_create_cache(
name=MonitoringRdsHostListProvider._MONITOR_CACHE_NAME,
cleanup_interval_ns=MonitoringRdsHostListProvider._CACHE_CLEANUP_NANO,
should_dispose_func=lambda monitor: monitor.can_dispose(),
item_disposal_func=lambda monitor: monitor.close()
)

def _get_monitor(self) -> Optional[ClusterTopologyMonitor]:
return self._monitors.compute_if_absent_with_disposal(self.get_cluster_id(),
lambda k: ClusterTopologyMonitorImpl(
Expand Down Expand Up @@ -803,7 +807,3 @@ def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int)
return ()

return monitor.force_refresh(should_verify_writer, timeout_sec)

@staticmethod
def release_resources():
MonitoringRdsHostListProvider._monitors.clear()
6 changes: 4 additions & 2 deletions aws_advanced_python_wrapper/host_monitoring_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,9 @@ class MonitoringThreadContainer:
_tasks_map: ConcurrentDict[Monitor, Future] = ConcurrentDict()
_executor_name: ClassVar[str] = "MonitoringThreadContainerExecutor"

def __init__(self):
    # Cache the shared executor so monitor tasks are always submitted to the
    # same pool. NOTE(review): __init__ runs on every
    # MonitoringThreadContainer() call even though __new__ enforces a
    # singleton; this assumes get_thread_pool is idempotent and returns the
    # same shared pool each time -- TODO confirm.
    self._thread_pool = ThreadPoolContainer.get_thread_pool(self._executor_name)

# This logic ensures that this class is a Singleton
def __new__(cls, *args, **kwargs):
if cls._instance is None:
Expand Down Expand Up @@ -605,8 +608,7 @@ def _get_or_create_monitor(_) -> Monitor:
raise AwsWrapperError(Messages.get("MonitoringThreadContainer.SupplierMonitorNone"))
self._tasks_map.compute_if_absent(
supplied_monitor,
lambda _: ThreadPoolContainer.get_thread_pool(MonitoringThreadContainer._executor_name)
.submit(supplied_monitor.run))
lambda _: self._thread_pool.submit(supplied_monitor.run))
return supplied_monitor

if monitor is None:
Expand Down
19 changes: 11 additions & 8 deletions aws_advanced_python_wrapper/host_monitoring_v2_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
PropertiesUtils,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \
SlidingExpirationCacheWithCleanupThread
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
SlidingExpirationCacheContainer
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
TelemetryCounter, TelemetryFactory, TelemetryTraceLevel)

Expand Down Expand Up @@ -450,19 +450,22 @@ def close(self) -> None:

class MonitorServiceV2:
# 1 Minute to Nanoseconds
_CACHE_CLEANUP_NANO = 1 * 60 * 1_000_000_000

_monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostMonitorV2]] = \
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO,
should_dispose_func=lambda monitor: monitor.can_dispose(),
item_disposal_func=lambda monitor: monitor.close())
_CACHE_CLEANUP_NANO: ClassVar[int] = 1 * 60 * 1_000_000_000
_MONITOR_CACHE_NAME: ClassVar[str] = "host_monitors_v2"

def __init__(self, plugin_service: PluginService):
    """Create the monitor service and register its shared monitor cache.

    :param plugin_service: plugin service used for telemetry and, later,
        monitor creation.
    """
    self._plugin_service: PluginService = plugin_service

    # Counter of connections aborted by enhanced failure monitoring (v2).
    telemetry_factory = self._plugin_service.get_telemetry_factory()
    self._aborted_connections_counter = telemetry_factory.create_counter("efm2.connections.aborted")

    # Fetch (or lazily create) the process-wide named monitor cache; entries
    # expire via the cleanup thread, which disposes monitors that report
    # can_dispose() by closing them.
    self._monitors = SlidingExpirationCacheContainer.get_or_create_cache(
        name=MonitorServiceV2._MONITOR_CACHE_NAME,
        cleanup_interval_ns=MonitorServiceV2._CACHE_CLEANUP_NANO,
        should_dispose_func=lambda monitor: monitor.can_dispose(),
        item_disposal_func=lambda monitor: monitor.close()
    )

def start_monitoring(
self,
conn: Connection,
Expand Down
Loading
Loading