Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions aws_advanced_python_wrapper/cluster_topology_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from aws_advanced_python_wrapper.utils.atomic import AtomicReference
from aws_advanced_python_wrapper.utils.cache_map import CacheMap
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.thread_safe_connection_holder import \
ThreadSafeConnectionHolder
from aws_advanced_python_wrapper.utils.utils import LogUtils

if TYPE_CHECKING:
Expand Down Expand Up @@ -86,7 +88,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils,
self._high_refresh_rate_nano = high_refresh_rate_nano

self._writer_host_info: AtomicReference[Optional[HostInfo]] = AtomicReference(None)
self._monitoring_connection: AtomicReference[Optional[Connection]] = AtomicReference(None)
self._monitoring_connection: ThreadSafeConnectionHolder = ThreadSafeConnectionHolder(None)

self._topology_updated = threading.Event()
self._request_to_update_topology = threading.Event()
Expand Down Expand Up @@ -123,7 +125,7 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[H
return current_hosts

if should_verify_writer:
self._close_connection_from_ref(self._monitoring_connection)
self._monitoring_connection.clear()
self._is_verified_writer_connection = False

result = self._wait_till_topology_gets_updated(timeout_sec)
Expand Down Expand Up @@ -177,7 +179,7 @@ def close(self) -> None:
self._monitor_thread.join(self.MONITOR_TERMINATION_TIMEOUT_SEC)

# Step 3: Now safe to close connections - no threads are using them
self._close_connection_from_ref(self._monitoring_connection)
self._monitoring_connection.clear()
self._close_connection_from_ref(self._host_threads_writer_connection)
self._close_connection_from_ref(self._host_threads_reader_connection)

Expand Down Expand Up @@ -220,8 +222,8 @@ def _monitor(self) -> None:
writer_connection = self._host_threads_writer_connection.get()
if (writer_connection is not None and writer_host_info is not None):
logger.debug("ClusterTopologyMonitorImpl.WriterPickedUpFromHostMonitors", self._cluster_id, writer_host_info.host)
self._close_connection_from_ref(self._monitoring_connection)
self._monitoring_connection.set(writer_connection)
# Transfer the writer connection to monitoring connection
self._monitoring_connection.set(writer_connection, close_previous=True)
self._writer_host_info.set(writer_host_info)
self._is_verified_writer_connection = True
self._high_refresh_rate_end_time_nano = (
Expand Down Expand Up @@ -259,9 +261,9 @@ def _monitor(self) -> None:
self._close_host_monitors()
self._submitted_hosts.clear()

hosts = self._fetch_topology_and_update_cache(self._monitoring_connection.get())
hosts = self._fetch_topology_and_update_cache_safe()
if not hosts:
self._close_connection_from_ref(self._monitoring_connection)
self._monitoring_connection.clear()
self._is_verified_writer_connection = False
self._writer_host_info.set(None)
continue
Expand All @@ -282,7 +284,7 @@ def _monitor(self) -> None:
finally:
self._stop.set()
self._close_host_monitors()
self._close_connection_from_ref(self._monitoring_connection)
self._monitoring_connection.clear()
logger.debug("ClusterTopologyMonitor.StopMonitoringThread", self._cluster_id, self._initial_host_info.host)

def _is_in_panic_mode(self) -> bool:
Expand All @@ -297,7 +299,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
# Try to connect to the initial host first
try:
conn = self._plugin_service.force_connect(self._initial_host_info, self._monitoring_properties)
self._monitoring_connection.set(conn)
self._monitoring_connection.set(conn, close_previous=False)
logger.debug("ClusterTopologyMonitorImpl.OpenedMonitoringConnection", self._cluster_id, self._initial_host_info.host)

try:
Expand All @@ -313,7 +315,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
except Exception:
return ()

hosts = self._fetch_topology_and_update_cache(self._monitoring_connection.get())
hosts = self._fetch_topology_and_update_cache_safe()
if writer_verified_by_this_thread:
if self._ignore_new_topology_requests_end_time_nano == -1:
self._ignore_new_topology_requests_end_time_nano = 0
Expand All @@ -322,7 +324,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]:
time.time_ns() + self.IGNORE_TOPOLOGY_REQUEST_NANO)

if len(hosts) == 0:
self._close_connection_from_ref(self._monitoring_connection)
self._monitoring_connection.clear()
self._is_verified_writer_connection = False
self._writer_host_info.set(None)

Expand Down Expand Up @@ -400,6 +402,16 @@ def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) ->
logger.debug("ClusterTopologyMonitorImpl.ErrorFetchingTopology", self._cluster_id, ex)
return ()

def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]:
"""
Safely fetch topology using ThreadSafeConnectionHolder to prevent race conditions.
The lock is held during the entire query operation.
"""
result = self._monitoring_connection.use_connection(
lambda conn: self._fetch_topology_and_update_cache(conn)
)
return result if result is not None else ()

def _query_for_topology(self, connection: Connection) -> Tuple[HostInfo, ...]:
hosts = self._topology_utils.query_for_topology(connection, self._plugin_service.driver_dialect)
if hosts is not None:
Expand Down
109 changes: 109 additions & 0 deletions aws_advanced_python_wrapper/utils/thread_safe_connection_holder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from __future__ import annotations

from threading import RLock
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from aws_advanced_python_wrapper.pep249 import Connection

from aws_advanced_python_wrapper.utils.log import Logger

logger = Logger(__name__)


class ThreadSafeConnectionHolder:
"""
Thread-safe connection container that ensures connections are properly
closed when replaced or cleared. This class prevents race conditions where
one thread might close a connection while another thread is using it.
"""

def __init__(self, initial_connection: Optional[Connection] = None):
self._connection: Optional[Connection] = initial_connection
self._lock: RLock = RLock()

def get(self) -> Optional[Connection]:
with self._lock:
return self._connection

def set(self, new_connection: Optional[Connection], close_previous: bool = True) -> None:
with self._lock:
old_connection = self._connection
self._connection = new_connection

if close_previous and old_connection is not None and old_connection != new_connection:
self._close_connection(old_connection)

def get_and_set(self, new_connection: Optional[Connection], close_previous: bool = True) -> Optional[Connection]:
with self._lock:
old_connection = self._connection
self._connection = new_connection

if close_previous and old_connection is not None and old_connection != new_connection:
self._close_connection(old_connection)

return old_connection

def compare_and_set(
self,
expected_connection: Optional[Connection],
new_connection: Optional[Connection],
close_previous: bool = True
) -> bool:
with self._lock:
if self._connection == expected_connection:
old_connection = self._connection
self._connection = new_connection

if close_previous and old_connection is not None and old_connection != new_connection:
self._close_connection(old_connection)

return True
return False

def clear(self) -> None:
self.set(None, close_previous=True)

def use_connection(self, func, *args, **kwargs):
"""
Safely use the connection within a locked context.

This method ensures the connection cannot be closed by another thread
while the provided function is executing.

:param func: Function to call with the connection as the first argument.
:param args: Additional positional arguments to pass to func.
:param kwargs: Additional keyword arguments to pass to func.
:return: The result of calling func, or None if no connection is available.

Example:
result = holder.use_connection(lambda conn: conn.cursor().execute("SELECT 1"))
"""
with self._lock:
if self._connection is None:
return None
return func(self._connection, *args, **kwargs)

def _close_connection(self, connection: Connection) -> None:
try:
if connection is not None:
connection.close()
except Exception: # ignore
pass

def __repr__(self) -> str:
with self._lock:
return f"ThreadSafeConnectionHolder(connection={self._connection})"
Loading
Loading