diff --git a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py index 927bd3c4..904e4941 100644 --- a/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py +++ b/aws_advanced_python_wrapper/sql_alchemy_connection_provider.py @@ -51,9 +51,9 @@ class SqlAlchemyPooledConnectionProvider(ConnectionProvider, CanReleaseResources "weighted_random": WeightedRandomHostSelector(), "highest_weight": HighestWeightHostSelector()} _rds_utils: ClassVar[RdsUtils] = RdsUtils() - _database_pools: ClassVar[SlidingExpirationCache[PoolKey, QueuePool]] = SlidingExpirationCache( - should_dispose_func=lambda queue_pool: queue_pool.checkedout() == 0, - item_disposal_func=lambda queue_pool: queue_pool.dispose() + _database_pools: ClassVar[SlidingExpirationCache[PoolKey, Tuple[QueuePool, Properties]]] = SlidingExpirationCache( + should_dispose_func=lambda pool_pair: pool_pair[0].checkedout() == 0, + item_disposal_func=lambda pool_pair: pool_pair[0].dispose() ) def __init__( @@ -119,7 +119,8 @@ def _num_connections(self, host_info: HostInfo) -> int: num_connections = 0 for pool_key, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items(): if pool_key.url == host_info.url: - num_connections += cache_item.item.checkedout() + queue_pool, _ = cache_item.item + num_connections += queue_pool.checkedout() return num_connections def connect( @@ -129,15 +130,22 @@ def connect( database_dialect: DatabaseDialect, host_info: HostInfo, props: Properties): - queue_pool: Optional[QueuePool] = SqlAlchemyPooledConnectionProvider._database_pools.compute_if_absent( + db_pool: Optional[Tuple[QueuePool, Properties]] = SqlAlchemyPooledConnectionProvider._database_pools.compute_if_absent( PoolKey(host_info.url, self._get_extra_key(host_info, props)), lambda _: self._create_pool(target_func, driver_dialect, database_dialect, host_info, props), SqlAlchemyPooledConnectionProvider._POOL_EXPIRATION_CHECK_NS ) - if queue_pool is None: + if db_pool is None: raise AwsWrapperError(Messages.get_formatted("SqlAlchemyPooledConnectionProvider.PoolNone", host_info.url)) + queue_pool, creator_props = db_pool + + # Update the password in the creator's captured properties so new pooled connections use the latest credentials + password = WrapperProperties.PASSWORD.get(props) + if password is not None: + creator_props[WrapperProperties.PASSWORD.name] = password + return queue_pool.connect() # The pool key should always be retrieved using this method, because the username @@ -163,7 +171,7 @@ def _create_pool( prepared_properties = driver_dialect.prepare_connect_info(host_info, props) database_dialect.prepare_conn_props(prepared_properties) kwargs["creator"] = self._get_connection_func(target_func, prepared_properties) - return self._create_sql_alchemy_pool(**kwargs) + return self._create_sql_alchemy_pool(**kwargs), prepared_properties def _get_connection_func(self, target_connect_func: Callable, props: Properties): return lambda: target_connect_func(**props) @@ -174,7 +182,8 @@ def _create_sql_alchemy_pool(self, **kwargs): def release_resources(self): for _, cache_item in SqlAlchemyPooledConnectionProvider._database_pools.items(): try: - cache_item.item.dispose() + queue_pool, _ = cache_item.item + queue_pool.dispose() except Exception: # Swallow exception, connections may already be dead pass diff --git a/aws_advanced_python_wrapper/utils/pg_exception_handler.py b/aws_advanced_python_wrapper/utils/pg_exception_handler.py index 8b2dee28..714798c7 100644 --- a/aws_advanced_python_wrapper/utils/pg_exception_handler.py +++ b/aws_advanced_python_wrapper/utils/pg_exception_handler.py @@ -75,7 +75,10 @@ def _is_network_error(self, error: Optional[BaseException], sql_state: Optional[ return False # Check the error message if this is a generic error error_msg: str = error.args[0] - return any(error_msg.startswith(msg) for msg in self._NETWORK_ERROR_MESSAGES) + is_network_error: bool = any(error_msg.startswith(msg) for msg in self._NETWORK_ERROR_MESSAGES) + # PAM errors may be nested in the connection error, double check here to avoid false positives + # Example nested error: 'connection failed: ...: FATAL: password authentication failed for user' + return is_network_error and not any(msg in error_msg for msg in self._ACCESS_ERROR_MESSAGES) return False diff --git a/tests/unit/test_exception_handling.py b/tests/unit/test_exception_handling.py index eebe9753..f186315d 100644 --- a/tests/unit/test_exception_handling.py +++ b/tests/unit/test_exception_handling.py @@ -162,6 +162,72 @@ def test_is_login_exception_with_nested_non_login_error_pg(pg_handler): assert pg_handler.is_login_exception(error=wrapper_error) is False +def test_nested_pam_error_in_connection_failed_is_not_network_exception_pg(pg_handler): + error_msg = ( + 'connection failed: connection to server at "", port 5432 failed: ' + 'FATAL: PAM authentication failed for user ""' + ) + wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg)) + + assert pg_handler.is_network_exception(error=wrapper_error) is False + + +def test_nested_password_auth_error_in_connection_failed_is_not_network_exception_pg(pg_handler): + error_msg = ( + 'connection failed: connection to server at "", port 5432 failed: ' + 'FATAL: password authentication failed for user ""' + ) + wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg)) + + assert pg_handler.is_network_exception(error=wrapper_error) is False + + +def test_nested_pam_error_deeply_wrapped_is_not_network_exception_pg(pg_handler): + error_msg = ( + 'connection failed: connection to server at "", port 5432 failed: ' + 'FATAL: password authentication failed for user ""' + ) + wrapper_error = AwsWrapperError( + "[IamAuthPlugin] Error occurred while opening a connection", + AwsWrapperError("Inner wrapper", OperationalError(error_msg))) + + assert pg_handler.is_network_exception(error=wrapper_error) is False + + +def test_nested_pam_error_is_login_exception_pg(pg_handler): + error_msg = ( + 'connection failed: connection to server at "", port 5432 failed: ' + 'FATAL: PAM authentication failed for user ""' + ) + wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg)) + + assert pg_handler.is_login_exception(error=wrapper_error) is True + + +def test_nested_password_auth_error_is_login_exception_pg(pg_handler): + error_msg = ( + 'connection failed: connection to server at "", port 5432 failed: ' + 'FATAL: password authentication failed for user ""' + ) + wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg)) + + assert pg_handler.is_login_exception(error=wrapper_error) is True + + +def test_pure_connection_failed_is_network_exception_pg(pg_handler): + error_msg = 'connection failed: connection to server at "", port 5432 failed: Connection refused' + wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg)) + + assert pg_handler.is_network_exception(error=wrapper_error) is True + + +def test_pure_connection_failed_is_not_login_exception_pg(pg_handler): + error_msg = 'connection failed: connection to server at "", port 5432 failed: Connection refused' + wrapper_error = AwsWrapperError("Error occurred while opening a connection", OperationalError(error_msg)) + + assert pg_handler.is_login_exception(error=wrapper_error) is False + + def test_is_read_only_exception_with_nested_aws_wrapper_error_mysql(mysql_handler): class MockReadOnlyError(Exception): def __init__(self): diff --git a/tests/unit/test_sql_alchemy_pooled_connection_provider.py b/tests/unit/test_sql_alchemy_pooled_connection_provider.py index 56e82795..5b87ef0d 100644 --- a/tests/unit/test_sql_alchemy_pooled_connection_provider.py +++ b/tests/unit/test_sql_alchemy_pooled_connection_provider.py @@ -63,6 +63,15 @@ def clear_cache(): SqlAlchemyPooledConnectionProvider._database_pools.clear() +@pytest.fixture +def mock_dialects(mocker): + mock_driver_dialect = mocker.MagicMock() + mock_database_dialect = mocker.MagicMock() + mock_driver_dialect.prepare_connect_info.side_effect = lambda host, props: Properties(props.copy()) + mock_database_dialect.prepare_conn_props.return_value = None + return mock_driver_dialect, mock_database_dialect + + def test_connect__default_mapping__default_pool_configuration(provider, host_info, mocker, mock_conn, mock_pool): expected_urls = {host_info.url} expected_keys = [PoolKey(host_info.url, "user1")] @@ -100,6 +109,69 @@ def test_connect__custom_configuration_and_mapping(host_info, mocker, mock_conn, mock_pool_initializer_func.assert_called_with(creator=mock_pool_connection_func, pool_size=10) +def test_connect__updates_password_in_cached_pool_creator_props(host_info, mocker, mock_dialects): + captured_password: list = [] + props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "TOKEN_1"}) + + def fake_target_connect(**kwargs): + captured_password.append(kwargs.get("password")) + return mocker.MagicMock(spec=psycopg.Connection) + + provider = SqlAlchemyPooledConnectionProvider( + pool_configurator=lambda _, __: {"pool_size": 0, "max_overflow": 2} + ) + mock_driver_dialect, mock_database_dialect = mock_dialects + + # Create a cached pool + conn_1 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props) + assert conn_1 is not None + + # Rotate password + props[WrapperProperties.PASSWORD.name] = "TOKEN_2" + + conn_2 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props) + assert conn_2 is not None + + assert captured_password == ["TOKEN_1", "TOKEN_2"] + + +def test_connect__password_update_different_pool_keys(host_info, mocker, mock_dialects): + captured_password_user1: list = [] + captured_password_user2: list = [] + props_user1 = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "TOKEN_1"}) + props_user2 = Properties({WrapperProperties.USER.name: "user2", WrapperProperties.PASSWORD.name: "TOKEN_2"}) + + def fake_target_connect(**kwargs): + user = kwargs.get("user") + pwd = kwargs.get("password") + if user == "user1": + captured_password_user1.append(pwd) + else: + captured_password_user2.append(pwd) + return mocker.MagicMock(spec=psycopg.Connection) + + provider = SqlAlchemyPooledConnectionProvider( + pool_configurator=lambda _, __: {"pool_size": 0, "max_overflow": 3} + ) + mock_driver_dialect, mock_database_dialect = mock_dialects + + conn1 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user1) + conn2 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user2) + assert conn1 is not None + assert conn2 is not None + + assert captured_password_user1 == ["TOKEN_1"] + assert captured_password_user2 == ["TOKEN_2"] + + # Rotate password for user 1 + props_user1[WrapperProperties.PASSWORD.name] = "TOKEN_3" + conn3 = provider.connect(fake_target_connect, mock_driver_dialect, mock_database_dialect, host_info, props_user1) + assert conn3 is not None + + assert captured_password_user1 == ["TOKEN_1", "TOKEN_3"] + assert captured_password_user2 == ["TOKEN_2"] + + def test_accepts_host_info(provider): instance_url = "instance-1.XYZ.us-east-2.rds.amazonaws.com" instance_host_info = HostInfo(instance_url) @@ -115,18 +187,24 @@ def test_least_connections_strategy(provider, mock_pool): writer = HostInfo("writer.XYZ.us-east-2.rds.amazonaws.com") reader_1 = HostInfo("reader-1.XYZ.us-east-2.rds.amazonaws.com", role=HostRole.READER) reader_2 = HostInfo("reader-2.XYZ.us-east-2.rds.amazonaws.com", role=HostRole.READER) - hosts = [writer, reader_1, reader_2] + hosts = (writer, reader_1, reader_2) props = Properties({WrapperProperties.USER.name: "user1", WrapperProperties.PASSWORD.name: "password"}) # Create cache with 1 pool to reader_url_1_connection and 2 pools to reader_url_2_connections. # Each pool holds 1 connection. test_database_pools = SlidingExpirationCache() test_database_pools.compute_if_absent( - PoolKey(reader_1.url, "user1"), lambda _: mock_pool, 10 * 60_000_000_000) + PoolKey(reader_1.url, "user1"), + lambda _: (mock_pool, Properties()), + 10 * 60_000_000_000) test_database_pools.compute_if_absent( - PoolKey(reader_2.url, "user1"), lambda _: mock_pool, 10 * 60_000_000_000) + PoolKey(reader_2.url, "user1"), + lambda _: (mock_pool, Properties()), + 10 * 60_000_000_000) test_database_pools.compute_if_absent( - PoolKey(reader_2.url, "user2"), lambda _: mock_pool, 10 * 60_000_000_000) + PoolKey(reader_2.url, "user2"), + lambda _: (mock_pool, Properties()), + 10 * 60_000_000_000) result = provider.get_host_info_by_strategy(hosts, HostRole.READER, "least_connections", props) assert reader_1 == result @@ -135,15 +213,21 @@ def test_least_connections_strategy(provider, mock_pool): def test_least_connections_strategy__no_hosts_matching_role(provider): props = Properties() with pytest.raises(AwsWrapperError): - provider.get_host_info_by_strategy([HostInfo("writer")], HostRole.READER, "least_connections", props) + provider.get_host_info_by_strategy((HostInfo("writer"),), HostRole.READER, "least_connections", props) def test_release_resources(provider, mocker): pool1 = mocker.MagicMock() pool2 = mocker.MagicMock() test_database_pools = SlidingExpirationCache() - test_database_pools.compute_if_absent(PoolKey("url1", "user1"), lambda _: pool1, 60_000_000_000) - test_database_pools.compute_if_absent(PoolKey("url1", "user2"), lambda _: pool2, 60_000_000_000) + test_database_pools.compute_if_absent( + PoolKey("url1", "user1"), + lambda _: (pool1, Properties()), + 60_000_000_000) + test_database_pools.compute_if_absent( + PoolKey("url1", "user2"), + lambda _: (pool2, Properties()), + 60_000_000_000) SqlAlchemyPooledConnectionProvider._database_pools = test_database_pools provider.release_resources()