Skip to content

Commit dece3e3

Browse files
committed
test: fix unit tests
1 parent e592f20 commit dece3e3

9 files changed

Lines changed: 521 additions & 57 deletions

File tree

aws_advanced_python_wrapper/aws_credentials_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_session(host_info: HostInfo, props: Properties, region: str) -> Session:
5858
# Initialize session outside of lock.
5959
session = handler(host_info, props) if handler else None
6060

61-
if session is not None and not isinstance(session, Session):
61+
if session is not None and not isinstance(session, type(Session())):
6262
raise TypeError(Messages.get_formatted("AwsCredentialsManager.InvalidHandler", type(session).__name__))
6363

6464
if session is None:

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from copy import deepcopy
1718
from html import unescape
1819
from re import DOTALL, findall, search
1920
from typing import TYPE_CHECKING, List
@@ -28,7 +29,6 @@
2829
from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils
2930

3031
if TYPE_CHECKING:
31-
from boto3 import Session
3232
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
3333
from aws_advanced_python_wrapper.hostinfo import HostInfo
3434
from aws_advanced_python_wrapper.pep249 import Connection
@@ -57,7 +57,7 @@ class FederatedAuthPlugin(Plugin):
5757
_rds_utils: RdsUtils = RdsUtils()
5858
_token_cache: Dict[str, TokenInfo] = {}
5959

60-
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None):
60+
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory):
6161
self._plugin_service = plugin_service
6262
self._credentials_provider_factory = credentials_provider_factory
6363

@@ -101,11 +101,13 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
101101

102102
token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key)
103103

104+
token_host_info = deepcopy(host_info)
105+
token_host_info.host = host
104106
if token_info is not None and not token_info.is_expired():
105107
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
106108
self._plugin_service.driver_dialect.set_password(props, token_info.token)
107109
else:
108-
self._update_authentication_token(host_info, props, user, region, cache_key)
110+
self._update_authentication_token(token_host_info, props, user, region, cache_key)
109111

110112
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))
111113

@@ -115,7 +117,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
115117
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
116118
raise e
117119

118-
self._update_authentication_token(host_info, props, user, region, cache_key)
120+
self._update_authentication_token(token_host_info, props, user, region, cache_key)
119121

120122
try:
121123
return connect_func()

aws_advanced_python_wrapper/iam_plugin.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from copy import deepcopy
1718
from typing import TYPE_CHECKING
1819

1920
from aws_advanced_python_wrapper.aws_credentials_manager import \
@@ -22,14 +23,13 @@
2223
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
2324

2425
if TYPE_CHECKING:
25-
from boto3 import Session
2626
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
2727
from aws_advanced_python_wrapper.hostinfo import HostInfo
2828
from aws_advanced_python_wrapper.pep249 import Connection
2929
from aws_advanced_python_wrapper.plugin_service import PluginService
3030

3131
from datetime import datetime, timedelta
32-
from typing import Callable, Dict, Optional, Set
32+
from typing import Callable, Dict, Set
3333

3434
from aws_advanced_python_wrapper.errors import AwsWrapperError
3535
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
@@ -51,7 +51,7 @@ class IamAuthPlugin(Plugin):
5151
_rds_utils: RdsUtils = RdsUtils()
5252
_token_cache: Dict[str, TokenInfo] = {}
5353

54-
def __init__(self, plugin_service: PluginService, session: Optional[Session] = None):
54+
def __init__(self, plugin_service: PluginService):
5555
self._plugin_service = plugin_service
5656

5757
self._region_utils = RegionUtils()
@@ -106,11 +106,13 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
106106
if self._fetch_token_counter is not None:
107107
self._fetch_token_counter.inc()
108108

109+
session_host_info = deepcopy(host_info)
110+
session_host_info.host = host
109111
session = AwsCredentialsManager.get_session(host_info, props, region)
110112
token: str = IamAuthUtils.generate_authentication_token(
111113
self._plugin_service,
112114
user,
113-
host_info.host,
115+
host,
114116
port,
115117
region,
116118
session)
@@ -132,7 +134,10 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
132134
token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec)
133135
if self._fetch_token_counter is not None:
134136
self._fetch_token_counter.inc()
135-
session = AwsCredentialsManager.get_session(host_info, props, region)
137+
138+
session_host_info = deepcopy(host_info)
139+
session_host_info.host = host
140+
session = AwsCredentialsManager.get_session(session_host_info, props, region)
136141
token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, session)
137142
self._plugin_service.driver_dialect.set_password(props, token)
138143
IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)

aws_advanced_python_wrapper/okta_plugin.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from copy import deepcopy
1718
from datetime import datetime, timedelta
1819
from html import unescape
1920
from re import search
@@ -28,7 +29,6 @@
2829
from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils
2930

3031
if TYPE_CHECKING:
31-
from boto3 import Session
3232
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
3333
from aws_advanced_python_wrapper.hostinfo import HostInfo
3434
from aws_advanced_python_wrapper.pep249 import Connection
@@ -53,7 +53,7 @@ class OktaAuthPlugin(Plugin):
5353
_rds_utils: RdsUtils = RdsUtils()
5454
_token_cache: Dict[str, TokenInfo] = {}
5555

56-
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None):
56+
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory):
5757
self._plugin_service = plugin_service
5858
self._credentials_provider_factory = credentials_provider_factory
5959

@@ -97,11 +97,13 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
9797

9898
token_info: Optional[TokenInfo] = OktaAuthPlugin._token_cache.get(cache_key)
9999

100+
token_host_info = deepcopy(host_info)
101+
token_host_info.host = host
100102
if token_info is not None and not token_info.is_expired():
101103
logger.debug("OktaAuthPlugin.UseCachedToken", token_info.token)
102104
self._plugin_service.driver_dialect.set_password(props, token_info.token)
103105
else:
104-
self._update_authentication_token(host_info, props, user, region, cache_key)
106+
self._update_authentication_token(token_host_info, props, user, region, cache_key)
105107

106108
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))
107109

@@ -111,7 +113,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
111113
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
112114
raise e
113115

114-
self._update_authentication_token(host_info, props, user, region, cache_key)
116+
self._update_authentication_token(token_host_info, props, user, region, cache_key)
115117

116118
try:
117119
return connect_func()
@@ -154,6 +156,12 @@ def _update_authentication_token(self,
154156
WrapperProperties.PASSWORD.set(props, token)
155157
OktaAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)
156158

159+
@staticmethod
160+
def release_resources() -> None:
161+
OktaAuthPlugin._token_cache.clear()
162+
AwsCredentialsManager.release_resources()
163+
return None
164+
157165

158166
class OktaCredentialsProviderFactory(SamlCredentialsProviderFactory):
159167
_SAML_RESPONSE_PATTERN = r"\"SAMLResponse\" .* value=\"(?P<saml>[^\"]+)\""
@@ -227,11 +235,6 @@ def get_saml_assertion(self, props: Properties):
227235
logger.debug(error_message, e)
228236
raise AwsWrapperError(Messages.get_formatted(error_message, e))
229237

230-
@staticmethod
231-
def release_resources() -> None:
232-
AwsCredentialsManager.release_resources()
233-
return None
234-
235238

236239
class OktaAuthPluginFactory(PluginFactory):
237240
@staticmethod

poetry.lock

Lines changed: 432 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ types_aws_xray_sdk = "^2.13.0"
3232
opentelemetry-api = "^1.22.0"
3333
opentelemetry-sdk = "^1.22.0"
3434
requests = "^2.32.2"
35+
boto3-stubs = "~=1.37.38"
3536

3637
[tool.poetry.group.dev.dependencies]
3738
mypy = "^1.9.0"

tests/unit/test_federated_auth_plugin.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
from unittest.mock import patch
2020

2121
import pytest
22+
from boto3 import Session
2223

24+
from aws_advanced_python_wrapper.aws_credentials_manager import \
25+
AwsCredentialsManager
2326
from aws_advanced_python_wrapper.federated_plugin import FederatedAuthPlugin
2427
from aws_advanced_python_wrapper.hostinfo import HostInfo
2528
from aws_advanced_python_wrapper.iam_plugin import TokenInfo
@@ -40,11 +43,12 @@
4043
@pytest.fixture(autouse=True)
4144
def clear_cache():
4245
_token_cache.clear()
46+
FederatedAuthPlugin.release_resources()
4347

4448

4549
@pytest.fixture
4650
def mock_session(mocker):
47-
return mocker.MagicMock()
51+
return mocker.MagicMock(spec=Session)
4852

4953

5054
@pytest.fixture
@@ -91,6 +95,13 @@ def mock_default_behavior(mock_session, mock_client, mock_func, mock_connection,
9195
"SecretAccessKey": "test-secret-access",
9296
"SessionToken": "test-session-token"}
9397

98+
def custom_handler(host_info: HostInfo, props: Properties) -> Session:
99+
return mock_session
100+
101+
AwsCredentialsManager.set_custom_handler(custom_handler)
102+
yield
103+
AwsCredentialsManager.reset_custom_handler()
104+
94105

95106
@patch("aws_advanced_python_wrapper.federated_plugin.FederatedAuthPlugin._token_cache", _token_cache)
96107
def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect):
@@ -129,7 +140,7 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu
129140
initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5))
130141
_token_cache[_PG_CACHE_KEY] = initial_token
131142

132-
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
143+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory)
133144

134145
target_plugin.connect(
135146
target_driver_func=mocker.MagicMock(),
@@ -154,7 +165,7 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m
154165
test_props: Properties = Properties({"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"})
155166
WrapperProperties.DB_USER.set(test_props, _DB_USER)
156167

157-
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
168+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory)
158169

159170
target_plugin.connect(
160171
target_driver_func=mocker.MagicMock(),
@@ -183,8 +194,7 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess
183194
exception_message = "generic exception"
184195
mock_func.side_effect = Exception(exception_message)
185196

186-
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory,
187-
mock_session)
197+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory)
188198
with pytest.raises(Exception) as e_info:
189199
target_plugin.connect(
190200
target_driver_func=mocker.MagicMock(),
@@ -229,11 +239,11 @@ def test_connect_with_specified_iam_host_port_region(mocker,
229239

230240
mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}"
231241

232-
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session)
242+
target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory)
233243
target_plugin.connect(
234244
target_driver_func=mocker.MagicMock(),
235245
driver_dialect=mock_dialect,
236-
host_info=HostInfo(expected_host),
246+
host_info=HostInfo("foo.com"),
237247
props=properties,
238248
is_initial_connection=False,
239249
connect_func=mock_func)

0 commit comments

Comments
 (0)