1919from unittest .mock import patch
2020
2121import pytest
22+ from boto3 import Session
2223
24+ from aws_advanced_python_wrapper .aws_credentials_manager import \
25+ AwsCredentialsManager
2326from aws_advanced_python_wrapper .federated_plugin import FederatedAuthPlugin
2427from aws_advanced_python_wrapper .hostinfo import HostInfo
2528from aws_advanced_python_wrapper .iam_plugin import TokenInfo
4043@pytest .fixture (autouse = True )
4144def clear_cache ():
4245 _token_cache .clear ()
46+ FederatedAuthPlugin .release_resources ()
4347
4448
4549@pytest .fixture
4650def mock_session (mocker ):
47- return mocker .MagicMock ()
51+ return mocker .MagicMock (spec = Session )
4852
4953
5054@pytest .fixture
@@ -91,6 +95,13 @@ def mock_default_behavior(mock_session, mock_client, mock_func, mock_connection,
9195 "SecretAccessKey" : "test-secret-access" ,
9296 "SessionToken" : "test-session-token" }
9397
98+ def custom_handler (host_info : HostInfo , props : Properties ) -> Session :
99+ return mock_session
100+
101+ AwsCredentialsManager .set_custom_handler (custom_handler )
102+ yield
103+ AwsCredentialsManager .reset_custom_handler ()
104+
94105
95106@patch ("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache" , _token_cache )
96107def test_pg_connect_valid_token_in_cache (mocker , mock_plugin_service , mock_session , mock_func , mock_client , mock_dialect ):
@@ -129,7 +140,7 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu
129140 initial_token = TokenInfo (_TEST_TOKEN , datetime .now () - timedelta (minutes = 5 ))
130141 _token_cache [_PG_CACHE_KEY ] = initial_token
131142
132- target_plugin : FederatedAuthPlugin = FederatedAuthPlugin (mock_plugin_service , mock_credentials_provider_factory , mock_session )
143+ target_plugin : FederatedAuthPlugin = FederatedAuthPlugin (mock_plugin_service , mock_credentials_provider_factory )
133144
134145 target_plugin .connect (
135146 target_driver_func = mocker .MagicMock (),
@@ -154,7 +165,7 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
154165 test_props : Properties = Properties ({"plugins" : "federated_auth" , "user" : "postgresqlUser" , "idp_username" : "user" , "idp_password" : "password" })
155166 WrapperProperties .DB_USER .set (test_props , _DB_USER )
156167
157- target_plugin : FederatedAuthPlugin = FederatedAuthPlugin (mock_plugin_service , mock_credentials_provider_factory , mock_session )
168+ target_plugin : FederatedAuthPlugin = FederatedAuthPlugin (mock_plugin_service , mock_credentials_provider_factory )
158169
159170 target_plugin .connect (
160171 target_driver_func = mocker .MagicMock (),
@@ -183,8 +194,7 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess
183194 exception_message = "generic exception"
184195 mock_func .side_effect = Exception (exception_message )
185196
186- target_plugin : FederatedAuthPlugin = FederatedAuthPlugin (mock_plugin_service , mock_credentials_provider_factory ,
187- mock_session )
197+ target_plugin : FederatedAuthPlugin = FederatedAuthPlugin (mock_plugin_service , mock_credentials_provider_factory )
188198 with pytest .raises (Exception ) as e_info :
189199 target_plugin .connect (
190200 target_driver_func = mocker .MagicMock (),
@@ -229,11 +239,11 @@ def test_connect_with_specified_iam_host_port_region(mocker,
229239
230240 mock_client .generate_db_auth_token .return_value = f"{ _TEST_TOKEN } :{ expected_region } "
231241
232- target_plugin : FederatedAuthPlugin = FederatedAuthPlugin (mock_plugin_service , mock_credentials_provider_factory , mock_session )
242+ target_plugin : FederatedAuthPlugin = FederatedAuthPlugin (mock_plugin_service , mock_credentials_provider_factory )
233243 target_plugin .connect (
234244 target_driver_func = mocker .MagicMock (),
235245 driver_dialect = mock_dialect ,
236- host_info = HostInfo (expected_host ),
246+ host_info = HostInfo ("foo.com" ),
237247 props = properties ,
238248 is_initial_connection = False ,
239249 connect_func = mock_func )
0 commit comments