Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ jobs:
markdown-link-check:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- uses: gaurav-nelson/github-action-markdown-link-check@v1
- uses: actions/checkout@v4
- uses: tcort/github-action-markdown-link-check@v1
with:
use-quiet-mode: 'yes'
folder-path: 'docs'
Expand Down
99 changes: 99 additions & 0 deletions aws_advanced_python_wrapper/aws_credentials_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from threading import Lock
from typing import TYPE_CHECKING, Any, Callable, Optional

from boto3 import Session

from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import WrapperProperties

if TYPE_CHECKING:
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.utils.properties import Properties


class AwsCredentialsManager:
_handler: Optional[Callable[[HostInfo, Properties], Optional[Session]]] = None
_lock = Lock()
_sessions: dict[str, Session] = {}
_clients: dict[str, Any] = {}

@staticmethod
def set_custom_handler(custom_handler: Callable[[HostInfo, Properties], Optional[Session]]) -> None:
if not callable(custom_handler):
raise TypeError("custom_handler must be callable")
with AwsCredentialsManager._lock:
AwsCredentialsManager._handler = custom_handler

@staticmethod
def reset_custom_handler() -> None:
with AwsCredentialsManager._lock:
AwsCredentialsManager._handler = None

@staticmethod
def get_session(host_info: HostInfo, props: Properties, region: str) -> Session:
host_key = f'{host_info.as_alias()}{region}'

handler = None
with AwsCredentialsManager._lock:
if host_key in AwsCredentialsManager._sessions:
return AwsCredentialsManager._sessions[host_key]
handler = AwsCredentialsManager._handler

# Initialize session outside of lock.
session = handler(host_info, props) if handler else None

if session is not None and not isinstance(session, type(Session())):
raise TypeError(Messages.get_formatted("AwsCredentialsManager.InvalidHandler", type(session).__name__))

if session is None:
profile_name = WrapperProperties.AWS_PROFILE.get(props)
session = Session(profile_name=profile_name, region_name=region) if profile_name else Session(region_name=region)

with AwsCredentialsManager._lock:
if host_key not in AwsCredentialsManager._sessions:
AwsCredentialsManager._sessions[host_key] = session
return AwsCredentialsManager._sessions[host_key]

@staticmethod
def get_client(service_name: str, session: Session, host: Optional[str], region: Optional[str], endpoint_url: Optional[str] = None):
key = f'{host}{region}{service_name}{endpoint_url}'

with AwsCredentialsManager._lock:
if key in AwsCredentialsManager._clients:
return AwsCredentialsManager._clients[key]

# Initialize client outside of lock.
if endpoint_url:
client = session.client(service_name=service_name, endpoint_url=endpoint_url) # type: ignore[call-overload]
else:
client = session.client(service_name=service_name) # type: ignore[call-overload]

with AwsCredentialsManager._lock:
if key not in AwsCredentialsManager._clients:
AwsCredentialsManager._clients[key] = client
return AwsCredentialsManager._clients[key]

@staticmethod
def release_resources() -> None:
with AwsCredentialsManager._lock:
for key, client in AwsCredentialsManager._clients.items():
client.close()
AwsCredentialsManager._clients.clear()
AwsCredentialsManager._sessions.clear()
return None
30 changes: 12 additions & 18 deletions aws_advanced_python_wrapper/aws_secrets_manager_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, Optional, Set, Tuple

import boto3
from botocore.exceptions import ClientError, EndpointConnectionError

from aws_advanced_python_wrapper.aws_credentials_manager import \
AwsCredentialsManager
from aws_advanced_python_wrapper.utils.cache_map import CacheMap

if TYPE_CHECKING:
Expand Down Expand Up @@ -86,7 +87,7 @@ def connect(
props: Properties,
is_initial_connection: bool,
connect_func: Callable) -> Connection:
return self._connect(props, connect_func)
return self._connect(host_info, props, connect_func)

def force_connect(
self,
Expand All @@ -96,16 +97,16 @@ def force_connect(
props: Properties,
is_initial_connection: bool,
force_connect_func: Callable) -> Connection:
return self._connect(props, force_connect_func)
return self._connect(host_info, props, force_connect_func)

def _connect(self, props: Properties, connect_func: Callable) -> Connection:
def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection:
token_expiration_sec: int = WrapperProperties.SECRETS_MANAGER_EXPIRATION.get_int(props)
# if value is less than 0, default to one year
if token_expiration_sec < 0:
token_expiration_sec = AwsSecretsManagerPlugin._ONE_YEAR_IN_SECONDS
token_expiration_ns = token_expiration_sec * 1_000_000_000

secret_fetched: bool = self._update_secret(token_expiration_ns=token_expiration_ns)
secret_fetched: bool = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns)

try:
self._apply_secret_to_properties(props)
Expand All @@ -116,7 +117,7 @@ def _connect(self, props: Properties, connect_func: Callable) -> Connection:
raise AwsWrapperError(
Messages.get_formatted("AwsSecretsManagerPlugin.ConnectException", e)) from e

secret_fetched = self._update_secret(token_expiration_ns=token_expiration_ns, force_refetch=True)
secret_fetched = self._update_secret(host_info, props, token_expiration_ns=token_expiration_ns, force_refetch=True)

if secret_fetched:
try:
Expand All @@ -128,7 +129,7 @@ def _connect(self, props: Properties, connect_func: Callable) -> Connection:
unhandled_error)) from unhandled_error
raise AwsWrapperError(Messages.get_formatted("AwsSecretsManagerPlugin.FailedLogin", e)) from e

def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False) -> bool:
def _update_secret(self, host_info: HostInfo, props: Properties, token_expiration_ns: int, force_refetch: bool = False) -> bool:
"""
Called to update credentials from the cache, or from the AWS Secrets Manager service.
:param token_expiration_ns: Expiration time in nanoseconds for secret stored in cache.
Expand All @@ -146,7 +147,7 @@ def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False)
endpoint = self._secret_key[2]
if not self._secret or force_refetch:
try:
self._secret = self._fetch_latest_credentials()
self._secret = self._fetch_latest_credentials(host_info, props)
if self._secret:
AwsSecretsManagerPlugin._secrets_cache.put(self._secret_key, self._secret, token_expiration_ns)
fetched = True
Expand Down Expand Up @@ -177,26 +178,19 @@ def _update_secret(self, token_expiration_ns: int, force_refetch: bool = False)
if context is not None:
context.close_context()

def _fetch_latest_credentials(self):
def _fetch_latest_credentials(self, host_info: HostInfo, props: Properties):
"""
Fetches the current credentials from AWS Secrets Manager service.

:return: a Secret object containing the credentials fetched from the AWS Secrets Manager service.
"""
session = self._session if self._session else boto3.Session()

client = session.client(
'secretsmanager',
region_name=self._secret_key[1],
endpoint_url=self._secret_key[2],
)
session = AwsCredentialsManager.get_session(host_info, props, self._secret_key[1])
client = AwsCredentialsManager.get_client("secretsmanager", session, host_info.host, self._secret_key[1], self._secret_key[2])

secret = client.get_secret_value(
SecretId=self._secret_key[0],
)

client.close()

return loads(secret.get("SecretString"), object_hook=lambda d: SimpleNamespace(**d))

def _apply_secret_to_properties(self, properties: Properties):
Expand Down
3 changes: 3 additions & 0 deletions aws_advanced_python_wrapper/cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from aws_advanced_python_wrapper.aws_credentials_manager import \
AwsCredentialsManager
from aws_advanced_python_wrapper.host_monitoring_plugin import \
MonitoringThreadContainer
from aws_advanced_python_wrapper.thread_pool_container import \
Expand All @@ -22,3 +24,4 @@ def release_resources() -> None:
"""Release all global resources used by the wrapper."""
MonitoringThreadContainer.clean_up()
ThreadPoolContainer.release_resources()
AwsCredentialsManager.release_resources()
17 changes: 7 additions & 10 deletions aws_advanced_python_wrapper/credentials_provider_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,29 @@

from typing import TYPE_CHECKING, Dict, Optional, Protocol

import boto3

if TYPE_CHECKING:
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.utils.properties import Properties

from abc import abstractmethod

from aws_advanced_python_wrapper.aws_credentials_manager import \
AwsCredentialsManager
from aws_advanced_python_wrapper.utils.properties import WrapperProperties


class CredentialsProviderFactory(Protocol):
@abstractmethod
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
def get_aws_credentials(self, region: str, props: Properties, host_info: HostInfo) -> Optional[Dict[str, str]]:
...


class SamlCredentialsProviderFactory(CredentialsProviderFactory):

def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
def get_aws_credentials(self, region: str, props: Properties, host_info: HostInfo) -> Optional[Dict[str, str]]:
saml_assertion: str = self.get_saml_assertion(props)
session = boto3.Session()

sts_client = session.client(
'sts',
region_name=region
)
session = AwsCredentialsManager.get_session(host_info, props, region)
sts_client = AwsCredentialsManager.get_client("sts", session, host_info.host, region)

response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml(
RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props),
Expand Down
2 changes: 1 addition & 1 deletion aws_advanced_python_wrapper/custom_endpoint_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class CustomEndpointPlugin(Plugin):
or removing an instance in the custom endpoint.
"""
_SUBSCRIBED_METHODS: ClassVar[Set[str]] = {DbApiMethod.CONNECT.method_name}
_CACHE_CLEANUP_RATE_NS: ClassVar[int] = 6 * 10 ^ 10 # 1 minute
_CACHE_CLEANUP_RATE_NS: ClassVar[int] = 60_000_000_000 # 1 minute
_monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, CustomEndpointMonitor]] = \
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_RATE_NS,
should_dispose_func=lambda _: True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, plugin_service: PluginService, props: Properties):
self._properties = props
self._host_response_time_service: HostResponseTimeService = \
HostResponseTimeService(plugin_service, props, WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props))
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 10 ^ 6
self._cache_expiration_nanos = WrapperProperties.RESPONSE_MEASUREMENT_INTERVAL_MS.get_int(props) * 1_000_000
self._random_host_selector = RandomHostSelector()
self._cached_fastest_response_host_by_role: CacheMap[str, HostInfo] = CacheMap()
self._hosts: Tuple[HostInfo, ...] = ()
Expand Down Expand Up @@ -278,8 +278,8 @@ def _open_connection(self):


class HostResponseTimeService:
_CACHE_EXPIRATION_NS: int = 6 * 10 ^ 11 # 10 minutes
_CACHE_CLEANUP_NS: int = 6 * 10 ^ 10 # 1 minute
_CACHE_EXPIRATION_NS: int = 10 * 60_000_000_000 # 10 minutes
_CACHE_CLEANUP_NS: int = 60_000_000_000 # 1 minute
_lock: Lock = Lock()
_monitoring_hosts: ClassVar[SlidingExpirationCacheWithCleanupThread[str, HostResponseTimeMonitor]] = \
SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NS,
Expand Down
20 changes: 12 additions & 8 deletions aws_advanced_python_wrapper/federated_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@

from __future__ import annotations

from copy import deepcopy
from html import unescape
from re import DOTALL, findall, search
from typing import TYPE_CHECKING, List
from urllib.parse import urlencode

from aws_advanced_python_wrapper.aws_credentials_manager import \
AwsCredentialsManager
from aws_advanced_python_wrapper.credentials_provider_factory import (
CredentialsProviderFactory, SamlCredentialsProviderFactory)
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils

if TYPE_CHECKING:
from boto3 import Session
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection
Expand Down Expand Up @@ -55,10 +57,9 @@ class FederatedAuthPlugin(Plugin):
_rds_utils: RdsUtils = RdsUtils()
_token_cache: Dict[str, TokenInfo] = {}

def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None):
def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory):
self._plugin_service = plugin_service
self._credentials_provider_factory = credentials_provider_factory
self._session = session

self._region_utils = RegionUtils()
telemetry_factory = self._plugin_service.get_telemetry_factory()
Expand Down Expand Up @@ -100,11 +101,13 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl

token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key)

token_host_info = deepcopy(host_info)
token_host_info.host = host
if token_info is not None and not token_info.is_expired():
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
self._plugin_service.driver_dialect.set_password(props, token_info.token)
else:
self._update_authentication_token(host_info, props, user, region, cache_key)
self._update_authentication_token(token_host_info, props, user, region, cache_key)

WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))

Expand All @@ -114,7 +117,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
raise e

self._update_authentication_token(host_info, props, user, region, cache_key)
self._update_authentication_token(token_host_info, props, user, region, cache_key)

try:
return connect_func()
Expand Down Expand Up @@ -142,18 +145,19 @@ def _update_authentication_token(self,
token_expiration_sec: int = WrapperProperties.IAM_TOKEN_EXPIRATION.get_int(props)
token_expiry: datetime = datetime.now() + timedelta(seconds=token_expiration_sec)
port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props, host_info)

if self._fetch_token_counter is not None:
self._fetch_token_counter.inc()
session = AwsCredentialsManager.get_session(host_info, props, region)
token: str = IamAuthUtils.generate_authentication_token(
self._plugin_service,
user,
host_info.host,
port,
region,
credentials,
self._session)
session,
credentials)
WrapperProperties.PASSWORD.set(props, token)
FederatedAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)

Expand Down
Loading