diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py index b1acf26f..4a15bb35 100644 --- a/aws_advanced_python_wrapper/cluster_topology_monitor.py +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -24,6 +24,8 @@ from aws_advanced_python_wrapper.utils.atomic import AtomicReference from aws_advanced_python_wrapper.utils.cache_map import CacheMap from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.thread_safe_connection_holder import \ + ThreadSafeConnectionHolder from aws_advanced_python_wrapper.utils.utils import LogUtils if TYPE_CHECKING: @@ -86,7 +88,7 @@ def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils, self._high_refresh_rate_nano = high_refresh_rate_nano self._writer_host_info: AtomicReference[Optional[HostInfo]] = AtomicReference(None) - self._monitoring_connection: AtomicReference[Optional[Connection]] = AtomicReference(None) + self._monitoring_connection: ThreadSafeConnectionHolder = ThreadSafeConnectionHolder(None) self._topology_updated = threading.Event() self._request_to_update_topology = threading.Event() @@ -123,7 +125,7 @@ def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[H return current_hosts if should_verify_writer: - self._close_connection_from_ref(self._monitoring_connection) + self._monitoring_connection.clear() self._is_verified_writer_connection = False result = self._wait_till_topology_gets_updated(timeout_sec) @@ -177,7 +179,7 @@ def close(self) -> None: self._monitor_thread.join(self.MONITOR_TERMINATION_TIMEOUT_SEC) # Step 3: Now safe to close connections - no threads are using them - self._close_connection_from_ref(self._monitoring_connection) + self._monitoring_connection.clear() self._close_connection_from_ref(self._host_threads_writer_connection) self._close_connection_from_ref(self._host_threads_reader_connection) @@ -220,8 +222,8 @@ def 
_monitor(self) -> None: writer_connection = self._host_threads_writer_connection.get() if (writer_connection is not None and writer_host_info is not None): logger.debug("ClusterTopologyMonitorImpl.WriterPickedUpFromHostMonitors", self._cluster_id, writer_host_info.host) - self._close_connection_from_ref(self._monitoring_connection) - self._monitoring_connection.set(writer_connection) + # Transfer the writer connection to monitoring connection + self._monitoring_connection.set(writer_connection, close_previous=True) self._writer_host_info.set(writer_host_info) self._is_verified_writer_connection = True self._high_refresh_rate_end_time_nano = ( @@ -259,9 +261,9 @@ def _monitor(self) -> None: self._close_host_monitors() self._submitted_hosts.clear() - hosts = self._fetch_topology_and_update_cache(self._monitoring_connection.get()) + hosts = self._fetch_topology_and_update_cache_safe() if not hosts: - self._close_connection_from_ref(self._monitoring_connection) + self._monitoring_connection.clear() self._is_verified_writer_connection = False self._writer_host_info.set(None) continue @@ -282,7 +284,7 @@ def _monitor(self) -> None: finally: self._stop.set() self._close_host_monitors() - self._close_connection_from_ref(self._monitoring_connection) + self._monitoring_connection.clear() logger.debug("ClusterTopologyMonitor.StopMonitoringThread", self._cluster_id, self._initial_host_info.host) def _is_in_panic_mode(self) -> bool: @@ -297,7 +299,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]: # Try to connect to the initial host first try: conn = self._plugin_service.force_connect(self._initial_host_info, self._monitoring_properties) - self._monitoring_connection.set(conn) + self._monitoring_connection.set(conn, close_previous=False) logger.debug("ClusterTopologyMonitorImpl.OpenedMonitoringConnection", self._cluster_id, self._initial_host_info.host) try: @@ -313,7 +315,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, 
...]: except Exception: return () - hosts = self._fetch_topology_and_update_cache(self._monitoring_connection.get()) + hosts = self._fetch_topology_and_update_cache_safe() if writer_verified_by_this_thread: if self._ignore_new_topology_requests_end_time_nano == -1: self._ignore_new_topology_requests_end_time_nano = 0 @@ -322,7 +324,7 @@ def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]: time.time_ns() + self.IGNORE_TOPOLOGY_REQUEST_NANO) if len(hosts) == 0: - self._close_connection_from_ref(self._monitoring_connection) + self._monitoring_connection.clear() self._is_verified_writer_connection = False self._writer_host_info.set(None) @@ -400,6 +402,16 @@ def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> logger.debug("ClusterTopologyMonitorImpl.ErrorFetchingTopology", self._cluster_id, ex) return () + def _fetch_topology_and_update_cache_safe(self) -> Tuple[HostInfo, ...]: + """ + Safely fetch topology using ThreadSafeConnectionHolder to prevent race conditions. + The lock is held during the entire query operation. + """ + result = self._monitoring_connection.use_connection( + lambda conn: self._fetch_topology_and_update_cache(conn) + ) + return result if result is not None else () + def _query_for_topology(self, connection: Connection) -> Tuple[HostInfo, ...]: hosts = self._topology_utils.query_for_topology(connection, self._plugin_service.driver_dialect) if hosts is not None: diff --git a/aws_advanced_python_wrapper/utils/thread_safe_connection_holder.py b/aws_advanced_python_wrapper/utils/thread_safe_connection_holder.py new file mode 100644 index 00000000..baa0aa61 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/thread_safe_connection_holder.py @@ -0,0 +1,109 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
class ThreadSafeConnectionHolder:
    """
    Thread-safe container for a single, optional connection.

    All reads and writes of the stored connection happen under one re-entrant
    lock. Replacing or clearing the stored connection can optionally close the
    connection that was replaced, and :meth:`use_connection` runs a callback
    while holding the lock so that no other thread can replace or close the
    stored connection through this holder until the callback returns.

    Connections are compared by identity (``is``), never by ``==``: two distinct
    connection objects that happen to compare equal are still distinct
    resources and must be closed/tracked separately.
    """

    def __init__(self, initial_connection: Optional[Connection] = None):
        """
        :param initial_connection: connection to store initially, or ``None``.
        """
        self._connection: Optional[Connection] = initial_connection
        # Re-entrant, so a thread that already holds the lock (e.g. inside a
        # use_connection() callback) can call back into this holder.
        self._lock: RLock = RLock()

    def get(self) -> Optional[Connection]:
        """
        Return the stored connection (or ``None``).

        The returned reference is not protected once this method returns:
        another thread may replace and close it via :meth:`set` or
        :meth:`clear`. Use :meth:`use_connection` to operate on the connection
        while it is guarded by the lock.
        """
        with self._lock:
            return self._connection

    def set(self, new_connection: Optional[Connection], close_previous: bool = True) -> None:
        """
        Store ``new_connection``.

        :param new_connection: connection to store, or ``None``.
        :param close_previous: when true, close the replaced connection unless
            it is ``None`` or the very same object as ``new_connection``.
        """
        with self._lock:
            self._swap(new_connection, close_previous)

    def get_and_set(self, new_connection: Optional[Connection], close_previous: bool = True) -> Optional[Connection]:
        """
        Store ``new_connection`` and return the connection that was replaced.

        :param new_connection: connection to store, or ``None``.
        :param close_previous: when true, close the replaced connection unless
            it is ``None`` or the very same object as ``new_connection``. Note
            that the returned connection has then already been closed.
        :return: the previously stored connection, or ``None``.
        """
        with self._lock:
            return self._swap(new_connection, close_previous)

    def compare_and_set(
            self,
            expected_connection: Optional[Connection],
            new_connection: Optional[Connection],
            close_previous: bool = True
    ) -> bool:
        """
        Store ``new_connection`` only if the stored connection is
        ``expected_connection`` (identity comparison).

        :param expected_connection: the object expected to be stored right now.
        :param new_connection: connection to store on a match, or ``None``.
        :param close_previous: when true and the swap happens, close the
            replaced connection unless it is ``None`` or the very same object
            as ``new_connection``.
        :return: ``True`` if the swap happened, ``False`` otherwise.
        """
        with self._lock:
            if self._connection is not expected_connection:
                return False
            self._swap(new_connection, close_previous)
            return True

    def clear(self) -> None:
        """Remove the stored connection and close it."""
        self.set(None, close_previous=True)

    def use_connection(self, func, *args, **kwargs):
        """
        Safely use the connection within a locked context.

        The lock is held for the entire duration of ``func``, so other threads
        calling :meth:`get`, :meth:`set`, :meth:`get_and_set`,
        :meth:`compare_and_set` or :meth:`clear` block until ``func`` returns.
        Keep ``func`` short where possible.

        :param func: Function to call with the connection as the first argument.
        :param args: Additional positional arguments to pass to func.
        :param kwargs: Additional keyword arguments to pass to func.
        :return: The result of calling func, or None if no connection is
            available. A ``None`` result is therefore ambiguous when ``func``
            itself may return ``None``.

        Example:
            result = holder.use_connection(lambda conn: conn.cursor().execute("SELECT 1"))
        """
        with self._lock:
            connection = self._connection
            if connection is None:
                return None
            return func(connection, *args, **kwargs)

    def _swap(self, new_connection: Optional[Connection], close_previous: bool) -> Optional[Connection]:
        """
        Replace the stored connection and return the old one.

        Must be called with ``self._lock`` held. Closes the old connection when
        ``close_previous`` is true, the old connection is not ``None`` and it is
        not the same object as ``new_connection``.
        """
        old_connection = self._connection
        self._connection = new_connection
        # Identity, not equality: an equal-but-distinct connection object is a
        # separate resource and would otherwise be leaked.
        if close_previous and old_connection is not None and old_connection is not new_connection:
            self._close_connection(old_connection)
        return old_connection

    def _close_connection(self, connection: Connection) -> None:
        """Close ``connection`` on a best-effort basis; errors are ignored."""
        try:
            if connection is not None:
                connection.close()
        except Exception:
            # Deliberately swallowed: the connection is being discarded and a
            # failure to close it must not break the caller.
            pass

    def __repr__(self) -> str:
        with self._lock:
            return f"ThreadSafeConnectionHolder(connection={self._connection})"
class TestThreadSafeConnectionHolder:
    """Unit tests for ThreadSafeConnectionHolder using MagicMock connections."""

    def test_get_set(self):
        """Test basic get and set operations."""
        holder = ThreadSafeConnectionHolder()
        assert holder.get() is None

        mock_conn = MagicMock()
        holder.set(mock_conn, close_previous=False)
        assert holder.get() == mock_conn

    def test_set_closes_previous(self):
        """Test that set closes the previous connection when requested."""
        holder = ThreadSafeConnectionHolder()

        mock_conn1 = MagicMock()
        mock_conn2 = MagicMock()

        holder.set(mock_conn1, close_previous=False)
        holder.set(mock_conn2, close_previous=True)

        mock_conn1.close.assert_called_once()
        mock_conn2.close.assert_not_called()

    def test_set_does_not_close_previous_when_disabled(self):
        """Test that set doesn't close previous connection when close_previous=False."""
        holder = ThreadSafeConnectionHolder()

        mock_conn1 = MagicMock()
        mock_conn2 = MagicMock()

        holder.set(mock_conn1, close_previous=False)
        holder.set(mock_conn2, close_previous=False)

        mock_conn1.close.assert_not_called()
        mock_conn2.close.assert_not_called()

    def test_set_same_connection_does_not_close(self):
        """Test that re-setting the stored connection never closes it."""
        holder = ThreadSafeConnectionHolder()

        mock_conn = MagicMock()
        holder.set(mock_conn, close_previous=False)
        holder.set(mock_conn, close_previous=True)

        assert holder.get() == mock_conn
        mock_conn.close.assert_not_called()

    def test_get_and_set(self):
        """Test get_and_set returns old value and sets new value."""
        holder = ThreadSafeConnectionHolder()

        mock_conn1 = MagicMock()
        mock_conn2 = MagicMock()

        holder.set(mock_conn1, close_previous=False)
        old_conn = holder.get_and_set(mock_conn2, close_previous=False)

        assert old_conn == mock_conn1
        assert holder.get() == mock_conn2
        mock_conn1.close.assert_not_called()

    def test_get_and_set_closes_previous(self):
        """Test get_and_set closes the replaced connection when requested."""
        holder = ThreadSafeConnectionHolder()

        mock_conn1 = MagicMock()
        mock_conn2 = MagicMock()

        holder.set(mock_conn1, close_previous=False)
        old_conn = holder.get_and_set(mock_conn2, close_previous=True)

        assert old_conn == mock_conn1
        mock_conn1.close.assert_called_once()
        mock_conn2.close.assert_not_called()

    def test_compare_and_set_success(self):
        """Test compare_and_set succeeds when expected value matches."""
        holder = ThreadSafeConnectionHolder()

        mock_conn1 = MagicMock()
        mock_conn2 = MagicMock()

        holder.set(mock_conn1, close_previous=False)
        result = holder.compare_and_set(mock_conn1, mock_conn2, close_previous=False)

        assert result is True
        assert holder.get() == mock_conn2

    def test_compare_and_set_closes_previous(self):
        """Test a successful compare_and_set closes the replaced connection when requested."""
        holder = ThreadSafeConnectionHolder()

        mock_conn1 = MagicMock()
        mock_conn2 = MagicMock()

        holder.set(mock_conn1, close_previous=False)
        result = holder.compare_and_set(mock_conn1, mock_conn2, close_previous=True)

        assert result is True
        mock_conn1.close.assert_called_once()
        mock_conn2.close.assert_not_called()

    def test_compare_and_set_failure(self):
        """Test compare_and_set fails when expected value doesn't match."""
        holder = ThreadSafeConnectionHolder()

        mock_conn1 = MagicMock()
        mock_conn2 = MagicMock()
        mock_conn3 = MagicMock()

        holder.set(mock_conn1, close_previous=False)
        result = holder.compare_and_set(mock_conn2, mock_conn3, close_previous=False)

        assert result is False
        assert holder.get() == mock_conn1

    def test_clear(self):
        """Test clear removes and closes connection."""
        holder = ThreadSafeConnectionHolder()

        mock_conn = MagicMock()
        holder.set(mock_conn, close_previous=False)
        holder.clear()

        assert holder.get() is None
        mock_conn.close.assert_called_once()

    def test_use_connection(self):
        """Test use_connection safely executes function with connection."""
        holder = ThreadSafeConnectionHolder()

        mock_conn = MagicMock()
        mock_conn.cursor.return_value.execute.return_value = "result"
        holder.set(mock_conn, close_previous=False)

        result = holder.use_connection(lambda conn: conn.cursor().execute("SELECT 1"))

        assert result == "result"
        mock_conn.cursor.assert_called_once()

    def test_use_connection_with_none(self):
        """Test use_connection returns None when no connection is set."""
        holder = ThreadSafeConnectionHolder()

        result = holder.use_connection(lambda conn: conn.cursor())

        assert result is None

    def test_thread_safety_concurrent_set(self):
        """Test that concurrent set operations are thread-safe."""
        holder = ThreadSafeConnectionHolder()
        connections = [MagicMock() for _ in range(10)]

        threads = [
            threading.Thread(target=holder.set, args=(conn,), kwargs={"close_previous": False})
            for conn in connections
        ]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        # Should have one of the connections set
        final_conn = holder.get()
        assert final_conn in connections

    def test_thread_safety_race_condition(self):
        """Test that the race condition is prevented - connection can't be closed while in use."""
        holder = ThreadSafeConnectionHolder()
        mock_conn = MagicMock()
        holder.set(mock_conn, close_previous=False)

        in_use = threading.Event()
        results = []
        closed_while_in_use = []
        errors = []

        def work(conn):
            # Signal that the lock is held, then give the closing thread time
            # to reach clear(); it has to wait for this callback to return.
            in_use.set()
            time.sleep(0.05)
            closed_while_in_use.append(conn.close.called)
            return conn

        def use_connection():
            try:
                results.append(holder.use_connection(work))
            except Exception as e:
                errors.append(e)

        def close_connection():
            # Only attempt the close once the connection is actually in use.
            in_use.wait(timeout=5)
            holder.clear()

        thread1 = threading.Thread(target=use_connection)
        thread2 = threading.Thread(target=close_connection)

        thread1.start()
        thread2.start()
        thread1.join()
        thread2.join()

        # Should have successfully used the connection
        assert len(results) == 1
        assert len(errors) == 0
        assert results[0] == mock_conn
        # close() must not have happened while the callback was running...
        assert closed_while_in_use == [False]
        # ...but must have happened once the callback released the lock.
        mock_conn.close.assert_called_once()
        assert holder.get() is None