Skip to content

Commit e592f20

Browse files
committed
chore: add type verification for custom handlers
1 parent 688e0fc commit e592f20

File tree

10 files changed

+43
-31
lines changed

10 files changed

+43
-31
lines changed

aws_advanced_python_wrapper/aws_credentials_manager.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
from typing import TYPE_CHECKING, Any, Callable, Optional
1919

2020
from boto3 import Session
21+
22+
from aws_advanced_python_wrapper.utils.messages import Messages
2123
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
24+
2225
if TYPE_CHECKING:
2326
from aws_advanced_python_wrapper.hostinfo import HostInfo
2427
from aws_advanced_python_wrapper.utils.properties import Properties
@@ -32,6 +35,8 @@ class AwsCredentialsManager:
3235

3336
@staticmethod
3437
def set_custom_handler(custom_handler: Callable[[HostInfo, Properties], Optional[Session]]) -> None:
38+
if not callable(custom_handler):
39+
raise TypeError("custom_handler must be callable")
3540
with AwsCredentialsManager._lock:
3641
AwsCredentialsManager._handler = custom_handler
3742

@@ -52,27 +57,30 @@ def get_session(host_info: HostInfo, props: Properties, region: str) -> Session:
5257

5358
# Initialize session outside of lock.
5459
session = handler(host_info, props) if handler else None
55-
60+
61+
if session is not None and not isinstance(session, Session):
62+
raise TypeError(Messages.get_formatted("AwsCredentialsManager.InvalidHandler", type(session).__name__))
63+
5664
if session is None:
5765
profile_name = WrapperProperties.AWS_PROFILE.get(props)
5866
session = Session(profile_name=profile_name, region_name=region) if profile_name else Session(region_name=region)
59-
67+
6068
with AwsCredentialsManager._lock:
6169
if host_key not in AwsCredentialsManager._sessions:
6270
AwsCredentialsManager._sessions[host_key] = session
6371
return AwsCredentialsManager._sessions[host_key]
6472

6573
@staticmethod
66-
def get_client(service_name: str, session: Session, host: str, region: str):
74+
def get_client(service_name: str, session: Session, host: Optional[str], region: Optional[str]):
6775
key = f'{host}{region}{service_name}'
68-
76+
6977
with AwsCredentialsManager._lock:
7078
if key in AwsCredentialsManager._clients:
7179
return AwsCredentialsManager._clients[key]
7280

7381
# Initialize client outside of lock.
74-
client = session.client(service_name)
75-
82+
client = session.client(service_name) # type: ignore[call-overload]
83+
7684
with AwsCredentialsManager._lock:
7785
if key not in AwsCredentialsManager._clients:
7886
AwsCredentialsManager._clients[key] = client

aws_advanced_python_wrapper/cleanup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from aws_advanced_python_wrapper.aws_credentials_manager import AwsCredentialsManager
15+
from aws_advanced_python_wrapper.aws_credentials_manager import \
16+
AwsCredentialsManager
1617
from aws_advanced_python_wrapper.host_monitoring_plugin import \
1718
MonitoringThreadContainer
1819
from aws_advanced_python_wrapper.thread_pool_container import \

aws_advanced_python_wrapper/credentials_provider_factory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@
1616

1717
from typing import TYPE_CHECKING, Dict, Optional, Protocol
1818

19-
from aws_advanced_python_wrapper.aws_credentials_manager import AwsCredentialsManager
20-
from aws_advanced_python_wrapper.hostinfo import HostInfo
21-
2219
if TYPE_CHECKING:
20+
from aws_advanced_python_wrapper.hostinfo import HostInfo
2321
from aws_advanced_python_wrapper.utils.properties import Properties
2422

2523
from abc import abstractmethod
2624

25+
from aws_advanced_python_wrapper.aws_credentials_manager import \
26+
AwsCredentialsManager
2727
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
2828

2929

aws_advanced_python_wrapper/federated_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
from typing import TYPE_CHECKING, List
2020
from urllib.parse import urlencode
2121

22-
import boto3
23-
24-
from aws_advanced_python_wrapper.aws_credentials_manager import AwsCredentialsManager
22+
from aws_advanced_python_wrapper.aws_credentials_manager import \
23+
AwsCredentialsManager
2524
from aws_advanced_python_wrapper.credentials_provider_factory import (
2625
CredentialsProviderFactory, SamlCredentialsProviderFactory)
2726
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
@@ -166,6 +165,7 @@ def release_resources() -> None:
166165
AwsCredentialsManager.release_resources()
167166
return None
168167

168+
169169
class FederatedAuthPluginFactory(PluginFactory):
170170
@staticmethod
171171
def get_instance(plugin_service: PluginService, props: Properties) -> Plugin:

aws_advanced_python_wrapper/iam_plugin.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616

1717
from typing import TYPE_CHECKING
1818

19-
import boto3
20-
21-
from aws_advanced_python_wrapper.aws_credentials_manager import AwsCredentialsManager
19+
from aws_advanced_python_wrapper.aws_credentials_manager import \
20+
AwsCredentialsManager
2221
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
2322
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
2423

aws_advanced_python_wrapper/okta_plugin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from re import search
2020
from typing import TYPE_CHECKING, Callable, Dict, Optional, Set
2121

22-
from aws_advanced_python_wrapper.aws_credentials_manager import AwsCredentialsManager
22+
from aws_advanced_python_wrapper.aws_credentials_manager import \
23+
AwsCredentialsManager
2324
from aws_advanced_python_wrapper.credentials_provider_factory import (
2425
CredentialsProviderFactory, SamlCredentialsProviderFactory)
2526
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo

aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ AdfsCredentialsProviderFactory.SignOnPagePostActionUrl=[AdfsCredentialsProviderF
2626
AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed=[AdfsCredentialsProviderFactory] ADFS SignOn Page POST action failed with HTTP status '{}', reason phrase '{}', and response '{}'
2727
AdfsCredentialsProviderFactory.SignOnPageUrl=[AdfsCredentialsProviderFactory] ADFS SignOn URL: '{}'
2828

29+
AwsCredentialsManager.InvalidHandler=[AwsCredentialsManager] Custom credentials provider set via AwsCredentialsManager.set_custom_handler must return a boto3.Session or None, got '{}'.
30+
2931
AwsSdk.UnsupportedRegion=[AwsSdk] Unsupported AWS region {}. For supported regions please read https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html
3032

3133
AwsSecretsManagerPlugin.ConnectException=[AwsSecretsManagerPlugin] Error occurred while opening a connection: {}

aws_advanced_python_wrapper/utils/iam_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717
from datetime import datetime
1818
from typing import TYPE_CHECKING, Dict, Optional
1919

20-
import boto3
21-
22-
from aws_advanced_python_wrapper.aws_credentials_manager import AwsCredentialsManager
20+
from aws_advanced_python_wrapper.aws_credentials_manager import \
21+
AwsCredentialsManager
2322
from aws_advanced_python_wrapper.errors import AwsWrapperError
2423
from aws_advanced_python_wrapper.utils.log import Logger
2524
from aws_advanced_python_wrapper.utils.messages import Messages

docs/examples/PGAwsCredentialsManager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import psycopg
1919

2020
from aws_advanced_python_wrapper import AwsWrapperConnection, release_resources
21-
from aws_advanced_python_wrapper.aws_credentials_manager import AwsCredentialsManager
21+
from aws_advanced_python_wrapper.aws_credentials_manager import \
22+
AwsCredentialsManager
2223

2324

2425
def custom_credentials_handler(host_info, props):

tests/unit/test_aws_credentials_manager.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import pytest
2020
from boto3 import Session
2121

22-
from aws_advanced_python_wrapper.aws_credentials_manager import AwsCredentialsManager
22+
from aws_advanced_python_wrapper.aws_credentials_manager import \
23+
AwsCredentialsManager
2324
from aws_advanced_python_wrapper.hostinfo import HostInfo
2425
from aws_advanced_python_wrapper.utils.atomic import AtomicInt
2526
from aws_advanced_python_wrapper.utils.properties import Properties
@@ -104,7 +105,7 @@ def test_get_session_different_regions(self, host_info, props, mocker):
104105
mock_session1.region_name = "us-east-1"
105106
mock_session2 = mocker.MagicMock(spec=Session)
106107
mock_session2.region_name = "us-west-2"
107-
108+
108109
mock_session_class = mocker.patch('aws_advanced_python_wrapper.aws_credentials_manager.Session')
109110
mock_session_class.side_effect = [mock_session1, mock_session2]
110111

@@ -121,7 +122,7 @@ def test_get_session_different_hosts(self, props, region, mocker):
121122

122123
mock_session1 = mocker.MagicMock(spec=Session)
123124
mock_session2 = mocker.MagicMock(spec=Session)
124-
125+
125126
mock_session_class = mocker.patch('aws_advanced_python_wrapper.aws_credentials_manager.Session')
126127
mock_session_class.side_effect = [mock_session1, mock_session2]
127128

@@ -144,7 +145,7 @@ def test_get_session_with_custom_handler(self, mock_session, host_info, props, r
144145
def test_reset_custom_handler(self, host_info, props, region, mocker):
145146
custom_session = mocker.MagicMock(spec=Session)
146147
custom_handler = mocker.MagicMock(return_value=custom_session)
147-
148+
148149
mock_default_session = mocker.MagicMock(spec=Session)
149150
mock_default_session.region_name = region
150151
mocker.patch('aws_advanced_python_wrapper.aws_credentials_manager.Session', return_value=mock_default_session)
@@ -200,10 +201,10 @@ def test_release_resources_clears_caches(self, host_info, props, region, mocker)
200201
mock_session1 = mocker.MagicMock(spec=Session)
201202
mock_session2 = mocker.MagicMock(spec=Session)
202203
mock_client = mocker.MagicMock()
203-
204+
204205
mock_session_class = mocker.patch('aws_advanced_python_wrapper.aws_credentials_manager.Session')
205206
mock_session_class.side_effect = [mock_session1, mock_session2]
206-
207+
207208
mock_session1.client.return_value = mock_client
208209

209210
session = AwsCredentialsManager.get_session(host_info, props, region)
@@ -218,7 +219,7 @@ def test_release_resources_clears_caches(self, host_info, props, region, mocker)
218219
def test_concurrent_get_session_same_host(self, mock_session, host_info, props, region, counter, concurrent_counter, num_threads):
219220
barrier = Barrier(num_threads)
220221
sessions = []
221-
222+
222223
def get_session_thread():
223224
barrier.wait()
224225
val = counter.get_and_increment()
@@ -275,13 +276,13 @@ def get_session_thread(thread_id):
275276
def test_concurrent_get_session_different_regions(self, num_threads, host_info, props, counter, concurrent_counter, regions, mocker):
276277
barrier = Barrier(num_threads)
277278
results = []
278-
279+
279280
# One session per region
280281
mock_sessions = {region: mocker.MagicMock(spec=Session) for region in regions}
281-
282+
282283
def session_factory(region_name=None, **kwargs):
283284
return mock_sessions[region_name]
284-
285+
285286
mocker.patch('aws_advanced_python_wrapper.aws_credentials_manager.Session', side_effect=session_factory)
286287

287288
def get_session_thread(thread_id):

0 commit comments

Comments
 (0)