Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 179 additions & 49 deletions aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@

from __future__ import annotations

import threading
from threading import Thread
from typing import (TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional,
Set, Tuple)
from time import perf_counter_ns
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet,
Optional, Set)

from aws_advanced_python_wrapper.utils.notifications import HostEvent
from aws_advanced_python_wrapper.utils.utils import Utils

if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection

from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
Expand All @@ -29,7 +35,6 @@
from _weakrefset import WeakSet

from aws_advanced_python_wrapper.errors import FailoverError
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
Expand All @@ -39,8 +44,74 @@


class OpenedConnectionTracker:
_opened_connections: Dict[str, WeakSet] = {}
_rds_utils = RdsUtils()
_opened_connections: ClassVar[Dict[str, WeakSet]] = {}
_lock: ClassVar[threading.Lock] = threading.Lock()
_rds_utils: ClassVar[RdsUtils] = RdsUtils()
_prune_thread: ClassVar[Optional[Thread]] = None
_prune_thread_started: ClassVar[bool] = False
_shutdown_event: ClassVar[threading.Event] = threading.Event()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is never set so thread is never shut down. Maybe we should have this in the release_resources method in the wrapper.py and change this in a release_resources() method or something like that).

_safe_to_check_closed_classes: ClassVar[Set[str]] = {"psycopg"}
_default_sleep_time: ClassVar[int] = 30

@classmethod
def _start_prune_thread(cls):
with cls._lock:
if not cls._prune_thread_started:
cls._prune_thread_started = True
cls._prune_thread = Thread(daemon=True, target=cls._prune_connections_loop)
cls._prune_thread.start()

@classmethod
def release_resources(cls):
cls._shutdown_event.set()
with cls._lock:
thread_to_join = cls._prune_thread
if thread_to_join is not None:
thread_to_join.join()
with cls._lock:
cls._opened_connections.clear()

@classmethod
def _prune_connections_loop(cls):
while not cls._shutdown_event.is_set():
try:
cls._prune_connections()
if cls._shutdown_event.wait(timeout=cls._default_sleep_time):
break
except Exception:
pass

@classmethod
def _prune_connections(cls):
with cls._lock:
opened_connections = list(cls._opened_connections.items())

to_remove_by_host = {}
for host, conn_set in opened_connections:
to_remove = []
for conn in list(conn_set):
if conn is None:
to_remove.append(conn)
else:
try:
# The following classes do not check connection validity via a DB server call
# so it is safe to check whether connection is already closed.
if any(safe_class in conn.__module__ for safe_class in cls._safe_to_check_closed_classes) and conn.is_closed():
to_remove.append(conn)
except Exception:
pass

if to_remove:
to_remove_by_host[host] = (conn_set, to_remove)

with cls._lock:
for host, (conn_set, to_remove) in to_remove_by_host.items():
for conn in to_remove:
conn_set.discard(conn)

# Remove empty connection sets
if not conn_set and host in cls._opened_connections:
del cls._opened_connections[host]

def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
"""
Expand All @@ -56,8 +127,8 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
self._track_connection(host_info.as_alias(), conn)
return

instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))),
None)
instance_endpoint: Optional[str] = next(
(alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), None)
if not instance_endpoint:
logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet")
return
Expand All @@ -73,7 +144,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
"""

if host_info:
self.invalidate_all_connections(host=frozenset(host_info.as_alias()))
self.invalidate_all_connections(host=frozenset([host_info.as_alias()]))
self.invalidate_all_connections(host=host_info.as_aliases())
return

Expand All @@ -89,27 +160,42 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
if not instance_endpoint:
return

connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
if connection_set is not None:
with self._lock:
connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
connections_list = list(connection_set) if connection_set is not None else None

if connections_list is not None:
self._log_connection_set(instance_endpoint, connection_set)
self._invalidate_connections(connection_set)
self._invalidate_connections(connections_list)

def _track_connection(self, instance_endpoint: str, conn: Connection):
connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
if connection_set is None:
connection_set = WeakSet()
connection_set.add(conn)
self._opened_connections[instance_endpoint] = connection_set
def remove_connection_tracking(self, host_info: HostInfo, connection: Connection | None):
if not connection:
return

if self._rds_utils.is_rds_instance(host_info.host):
host = host_info.as_alias()
else:
connection_set.add(conn)
host = next((alias for alias in host_info.as_aliases()
if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), "")

if not host:
return

with self._lock:
connection_set = self._opened_connections.get(host)
if connection_set:
connection_set.discard(connection)

def _track_connection(self, instance_endpoint: str, conn: Connection):
with self._lock:
connection_set = self._opened_connections.setdefault(instance_endpoint, WeakSet())
connection_set.add(conn)
self._start_prune_thread()
self.log_opened_connections()

@staticmethod
def _task(connection_set: WeakSet):
while connection_set is not None and len(connection_set) > 0:
conn_reference = connection_set.pop()

def _task(connections_list: list):
for conn_reference in connections_list:
if conn_reference is None:
continue

Expand All @@ -119,37 +205,38 @@ def _task(connection_set: WeakSet):
# Swallow this exception, current connection should be useless anyway
pass

def _invalidate_connections(self, connection_set: WeakSet):
def _invalidate_connections(self, connections_list: list):
invalidate_connection_thread: Thread = Thread(daemon=True, target=self._task,
args=[connection_set]) # type: ignore
args=[connections_list]) # type: ignore
invalidate_connection_thread.start()

def log_opened_connections(self):
msg = ""
for key, conn_set in self._opened_connections.items():
conn = ""
for item in list(conn_set):
conn += f"\n\t\t{item}"
with self._lock:
opened_connections = [(key, list(conn_set)) for key, conn_set in self._opened_connections.items()]

msg += f"\t[{key} : {conn}]"
msg_parts = []
for key, conn_list in opened_connections:
conn_parts = [f"\n\t\t{item}" for item in conn_list]
conn = "".join(conn_parts)
msg_parts.append(f"\t[{key} : {conn}]")

msg = "".join(msg_parts)
return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg)

def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]):
if conn_set is None or len(conn_set) == 0:
return

conn = ""
for item in list(conn_set):
conn += f"\n\t\t{item}"

conn_parts = [f"\n\t\t{item}" for item in list(conn_set)]
conn = "".join(conn_parts)
msg = host + f"[{conn}\n]"
logger.debug("OpenedConnectionTracker.InvalidatingConnections", msg)


class AuroraConnectionTrackerPlugin(Plugin):
_current_writer: Optional[HostInfo] = None
_need_update_current_writer: bool = False
_host_list_refresh_end_time_nano: ClassVar[int] = 0
_refresh_lock: ClassVar[threading.Lock] = threading.Lock()
_TOPOLOGY_CHANGES_EXPECTED_TIME_NANO: ClassVar[int] = 3 * 60 * 1_000_000_000 # 3 minutes

@property
def subscribed_methods(self) -> Set[str]:
Expand All @@ -164,6 +251,8 @@ def __init__(self,
self._props = props
self._rds_utils = rds_utils
self._tracker = tracker
self._current_writer: Optional[HostInfo] = None
self._need_update_current_writer: bool = False
self._subscribed_methods: Set[str] = {DbApiMethod.CONNECT.method_name,
DbApiMethod.CONNECTION_CLOSE.method_name,
DbApiMethod.CONNECT.method_name,
Expand Down Expand Up @@ -192,26 +281,67 @@ def connect(
return conn

def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
current_host = self._plugin_service.current_host_info
if self._current_writer is None or self._need_update_current_writer:
self._current_writer = self._get_writer(self._plugin_service.all_hosts)
self._current_writer = Utils.get_writer(self._plugin_service.all_hosts)
self._need_update_current_writer = False

try:
return execute_func()
if not method_name == DbApiMethod.CONNECTION_CLOSE.method_name:
need_refresh_host_lists = False
with AuroraConnectionTrackerPlugin._refresh_lock:
local_host_list_refresh_end_time_nano = AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano
if local_host_list_refresh_end_time_nano > 0:
if local_host_list_refresh_end_time_nano > perf_counter_ns():
# The time specified in hostListRefreshThresholdTimeNano isn't yet reached.
# Need to continue to refresh host list.
need_refresh_host_lists = True
else:
# The time specified in hostListRefreshThresholdTimeNano is reached, and we can stop further refreshes
# of host list.
AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = 0

if self._need_update_current_writer or need_refresh_host_lists:
# Calling this method may effectively close/abort a current connection
self._check_writer_changed(need_refresh_host_lists)

result = execute_func()
if method_name == DbApiMethod.CONNECTION_CLOSE.method_name:
self._tracker.remove_connection_tracking(current_host, self._plugin_service.current_connection)
return result

except Exception as e:
# Check that e is a FailoverError and that the writer has changed
if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.all_hosts) != self._current_writer:
self._tracker.invalidate_all_connections(host_info=self._current_writer)
self._tracker.log_opened_connections()
self._need_update_current_writer = True
raise e
if isinstance(e, FailoverError):
with AuroraConnectionTrackerPlugin._refresh_lock:
AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = (
perf_counter_ns() + AuroraConnectionTrackerPlugin._TOPOLOGY_CHANGES_EXPECTED_TIME_NANO)
# Calling this method may effectively close/abort a current connection
self._check_writer_changed(True)
raise

def _check_writer_changed(self, need_refresh_host_lists: bool):
if need_refresh_host_lists:
self._plugin_service.refresh_host_list()

host_info_after_failover = Utils.get_writer(self._plugin_service.all_hosts)
if host_info_after_failover is None:
return

if self._current_writer is None:
self._current_writer = host_info_after_failover
self._need_update_current_writer = False
elif not self._current_writer.get_host_and_port() == host_info_after_failover.get_host_and_port():
self._tracker.invalidate_all_connections(host_info=self._current_writer)
self._tracker.log_opened_connections()
self._current_writer = host_info_after_failover
self._need_update_current_writer = False

def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
for host in hosts:
if host.role == HostRole.WRITER:
return host
return None
def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
for node, node_changes in changes.items():
if HostEvent.CONVERTED_TO_READER in node_changes:
self._tracker.invalidate_all_connections(host=frozenset([node]))
if HostEvent.CONVERTED_TO_WRITER in node_changes:
self._need_update_current_writer = True


class AuroraConnectionTrackerPluginFactory(PluginFactory):
Expand Down
3 changes: 3 additions & 0 deletions aws_advanced_python_wrapper/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from aws_advanced_python_wrapper.aurora_connection_tracker_plugin import \
OpenedConnectionTracker
from aws_advanced_python_wrapper.host_monitoring_plugin import \
MonitoringThreadContainer
from aws_advanced_python_wrapper.thread_pool_container import \
Expand All @@ -22,3 +24,4 @@ def release_resources() -> None:
"""Release all global resources used by the wrapper."""
MonitoringThreadContainer.clean_up()
ThreadPoolContainer.release_resources()
OpenedConnectionTracker.release_resources()
6 changes: 3 additions & 3 deletions aws_advanced_python_wrapper/failover_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,12 @@ def _failover_writer(self):

writer_host = self._get_writer(result.topology)
allowed_hosts = self._plugin_service.hosts
allowed_hostnames = [host.host for host in allowed_hosts]
if writer_host.host not in allowed_hostnames:
allowed_hostnames = [host.get_host_and_port() for host in allowed_hosts]
if writer_host.get_host_and_port() not in allowed_hostnames:
raise FailoverFailedError(
Messages.get_formatted(
"FailoverPlugin.NewWriterNotAllowed",
"<null>" if writer_host is None else writer_host.host,
"<null>" if writer_host is None else writer_host.get_host_and_port(),
LogUtils.log_topology(allowed_hosts)))

self._plugin_service.set_current_connection(result.new_connection, writer_host)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from __future__ import annotations

import time
from copy import copy
from dataclasses import dataclass
from datetime import datetime
from threading import Event, Lock, Thread
from time import sleep
from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, List, Optional,
Expand Down Expand Up @@ -96,7 +96,7 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Op

# Found a fastest host. Let's find it in the latest topology.
for host in self._plugin_service.hosts:
if host == fastest_response_host:
if host.get_host_and_port() == fastest_response_host.get_host_and_port():
# found the fastest host in the topology
return host
# It seems that the fastest cached host isn't in the latest topology.
Expand Down Expand Up @@ -196,7 +196,7 @@ def close(self):
logger.debug("HostResponseTimeMonitor.Stopped", self._host_info.host)

def _get_current_time(self):
return datetime.now().microsecond / 1000 # milliseconds
return time.perf_counter() * 1000 # milliseconds

def run(self):
context: TelemetryContext = self._telemetry_factory.open_telemetry_context(
Expand Down
Loading