-
Notifications
You must be signed in to change notification settings - Fork 18
Expand file tree
/
Copy pathfederated_plugin.py
More file actions
306 lines (250 loc) · 13.4 KB
/
federated_plugin.py
File metadata and controls
306 lines (250 loc) · 13.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from html import unescape
from re import DOTALL, findall, search
from typing import TYPE_CHECKING, List
from urllib.parse import urlencode
from aws_advanced_python_wrapper.credentials_provider_factory import (
CredentialsProviderFactory, SamlCredentialsProviderFactory)
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils
if TYPE_CHECKING:
from boto3 import Session
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.token_utils import TokenUtils
from datetime import datetime, timedelta
from typing import Callable, Dict, Optional, Set
import requests
from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
# Module-level logger. Call sites in this module pass message-resource keys
# (e.g. "RdsUtils.UnsupportedHostname") plus format arguments -- the same keys
# handed to Messages.get_formatted -- rather than literal log text.
logger = Logger(__name__)
class FederatedAuthPlugin(Plugin):
    """Connection plugin that authenticates with IAM tokens obtained through SAML federation.

    On ``connect``/``force_connect`` it resolves an IAM authentication token for the
    (user, host, port, region) combination, reusing a cached token while it is still
    valid, and injects it into the connection properties as the password.
    """

    # Not referenced inside this class; retained in case other code reads it.
    _HTTPS_URL_PATTERN = r"^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']"
    _SUBSCRIBED_METHODS: Set[str] = {"connect", "force_connect"}

    _rds_utils: RdsUtils = RdsUtils()
    # Class-level cache shared by every FederatedAuthPlugin instance, keyed by
    # IamAuthUtils.get_cache_key(user, host, port, region).
    _token_cache: Dict[str, TokenInfo] = {}

    def __init__(self,
                 plugin_service: PluginService,
                 credentials_provider_factory: CredentialsProviderFactory,
                 token_utils: TokenUtils,
                 session: Optional[Session] = None):
        """
        :param plugin_service: service providing dialects, telemetry and login-error detection.
        :param credentials_provider_factory: source of the AWS credentials used to sign tokens.
        :param token_utils: generator of IAM authentication tokens.
        :param session: optional boto3 session used for region lookup and token generation.
        """
        self._plugin_service = plugin_service
        self._credentials_provider_factory = credentials_provider_factory
        self._session = session

        self._region_utils = RegionUtils()
        self._token_utils = token_utils

        telemetry_factory = self._plugin_service.get_telemetry_factory()
        self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count")
        # Gauge reports the current size of the shared token cache.
        self._cache_size_gauge = telemetry_factory.create_gauge(
            "federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache))

    @property
    def subscribed_methods(self) -> Set[str]:
        """Pipeline methods this plugin intercepts."""
        return self._SUBSCRIBED_METHODS

    def connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            connect_func: Callable) -> Connection:
        return self._connect(host_info, props, connect_func)

    def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection:
        """Inject an IAM token into ``props`` and open the connection.

        If a cached, non-expired token was used and the first attempt fails with a
        login exception, a fresh token is generated and the connection is retried once.

        :raises AwsWrapperError: if no region can be determined for the host, or if the
            retry with a freshly generated token also fails.
        """
        SamlUtils.check_idp_credentials_with_fallback(props)

        # get_iam_host derives the host from props and host_info (presumably an IAM host
        # override, else host_info.host). The same host is used for the region lookup,
        # the cache key and the token itself so they always agree.
        host = IamAuthUtils.get_iam_host(props, host_info)
        port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
        region = self._region_utils.get_region(props, WrapperProperties.IAM_REGION.name, host, self._session)
        if not region:
            error_message = "RdsUtils.UnsupportedHostname"
            logger.debug(error_message, host)
            raise AwsWrapperError(Messages.get_formatted(error_message, host))

        user = WrapperProperties.DB_USER.get(props)
        cache_key: str = IamAuthUtils.get_cache_key(user, host, port, region)

        token_info: Optional[TokenInfo] = FederatedAuthPlugin._token_cache.get(cache_key)
        if token_info is not None and not token_info.is_expired():
            logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
            self._plugin_service.driver_dialect.set_password(props, token_info.token)
        else:
            self._update_authentication_token(host_info, props, user, region, cache_key, host)

        WrapperProperties.USER.set(props, user)

        try:
            return connect_func()
        except Exception as e:
            # Only a rejected *cached* token is worth regenerating; anything else propagates unchanged.
            if token_info is None or token_info.is_expired() or not self._plugin_service.is_login_exception(e):
                raise

            self._update_authentication_token(host_info, props, user, region, cache_key, host)

            try:
                return connect_func()
            except Exception as retry_error:
                error_message = "FederatedAuthPlugin.UnhandledException"
                logger.debug(error_message, retry_error)
                raise AwsWrapperError(Messages.get_formatted(error_message, retry_error)) from retry_error

    def force_connect(
            self,
            target_driver_func: Callable,
            driver_dialect: DriverDialect,
            host_info: HostInfo,
            props: Properties,
            is_initial_connection: bool,
            force_connect_func: Callable) -> Connection:
        return self._connect(host_info, props, force_connect_func)

    def _update_authentication_token(self,
                                     host_info: HostInfo,
                                     props: Properties,
                                     user: Optional[str],
                                     region: str,
                                     cache_key: str,
                                     host: Optional[str] = None) -> None:
        """Generate a new IAM token, set it as the connection password and cache it.

        The token expires ``IAM_TOKEN_EXPIRATION`` seconds from now.

        :param host: host the token is generated for. Defaults to ``host_info.host``;
            callers should pass the same host used to build ``cache_key``.
        """
        token_expiration_sec: int = WrapperProperties.IAM_TOKEN_EXPIRATION.get_int(props)
        token_expiry: datetime = datetime.now() + timedelta(seconds=token_expiration_sec)
        token_host: str = host if host is not None else host_info.host
        port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
        credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)

        self._fetch_token_counter.inc()
        token: str = self._token_utils.generate_authentication_token(
            self._plugin_service,
            user,
            token_host,
            port,
            region,
            credentials,
            self._session)
        # Set the password the same way as the cached-token path in _connect.
        self._plugin_service.driver_dialect.set_password(props, token)
        FederatedAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)
class FederatedAuthPluginFactory(PluginFactory):
    """Creates FederatedAuthPlugin instances backed by the configured identity provider."""

    def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin:
        credentials_factory = self.get_credentials_provider_factory(plugin_service, props)
        return FederatedAuthPlugin(plugin_service, credentials_factory, RDSTokenUtils())

    def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> AdfsCredentialsProviderFactory:
        """Return the SAML credentials provider factory selected by the ``IDP_NAME`` property.

        An unset or empty IDP name defaults to ADFS. The name is matched
        case-insensitively, so ``"ADFS"`` and ``"adfs"`` are equivalent.

        :raises AwsWrapperError: if ``IDP_NAME`` names an unsupported identity provider.
        """
        idp_name = WrapperProperties.IDP_NAME.get(props)
        if not idp_name or str(idp_name).lower() == "adfs":
            return AdfsCredentialsProviderFactory(plugin_service, props)

        raise AwsWrapperError(Messages.get_formatted("FederatedAuthPluginFactory.UnsupportedIdp", idp_name))
class AdfsCredentialsProviderFactory(SamlCredentialsProviderFactory):
    """Obtains a SAML assertion by submitting the ADFS IdP-initiated sign-on form over HTTPS."""

    _INPUT_TAG_PATTERN = r"<input(.+?)/>"
    _FORM_ACTION_PATTERN = r"<form.*?action=\"([^\"]+)\""
    _SAML_RESPONSE_PATTERN = r"\"SAMLResponse\" value=\"(?P<saml>[^\"]+)\""
    _SAML_RESPONSE_PATTERN_GROUP = "saml"
    _HTTPS_URL_PATTERN = r"^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_']"

    def __init__(self, plugin_service: PluginService, props: Properties):
        self._plugin_service = plugin_service
        self._properties = props

    def get_saml_assertion(self, props: Properties):
        """Sign in to ADFS with the IdP credentials in ``props`` and return the SAMLResponse value.

        Steps:
        1. GET the IdP-initiated sign-on page.
        2. Collect its form inputs, filling in the IdP username and password.
        3. POST the inputs to the form action.
        4. Extract the ``SAMLResponse`` input value from the returned page.

        :raises AwsWrapperError: if the response contains no SAMLResponse, or when an
            ``IOError`` occurs during the exchange (the original error is chained).
        """
        try:
            uri = self._get_sign_in_page_url(props)
            sign_in_page_body: str = self._get_sign_in_page_body(uri, props)
            action: str = self._get_form_action_from_html_body(sign_in_page_body)

            # A relative form action is resolved against the IdP endpoint. Otherwise
            # (absent or not starting with "/") the sign-in page URL is the POST target.
            if action.startswith("/"):
                uri = self._get_form_action_url(props, action)

            params: Dict[str, str] = self._get_parameters_from_html_body(sign_in_page_body, props)
            content: str = self._post_form_action_body(uri, params, props)

            match = search(self._SAML_RESPONSE_PATTERN, content)
            if not match:
                error_message = "AdfsCredentialsProviderFactory.FailedLogin"
                logger.debug(error_message, content)
                raise AwsWrapperError(Messages.get_formatted(error_message, content))

            return match.group(self._SAML_RESPONSE_PATTERN_GROUP)
        except IOError as e:
            error_message = "FederatedAuthPlugin.UnhandledException"
            logger.debug(error_message, e)
            raise AwsWrapperError(Messages.get_formatted(error_message, e)) from e

    def _get_sign_in_page_body(self, url: str, props: Properties) -> str:
        """GET the sign-on page at ``url`` (after validating it) and return the response text."""
        logger.debug("AdfsCredentialsProviderFactory.SignOnPageUrl", url)
        SamlUtils.validate_url(url)
        r = requests.get(url,
                         verify=WrapperProperties.SSL_SECURE.get_bool(props),
                         timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props))

        SamlUtils.validate_response(r)
        return r.text

    def _post_form_action_body(self, uri: str, parameters: Dict[str, str], props: Properties) -> str:
        """POST ``parameters`` as a URL-encoded form to ``uri`` and return the response text."""
        logger.debug("AdfsCredentialsProviderFactory.SignOnPagePostActionUrl", uri)
        SamlUtils.validate_url(uri)
        r = requests.post(uri,
                          data=urlencode(parameters),
                          # requests sets this header automatically only for dict/list data;
                          # a pre-encoded string body would otherwise be sent without one.
                          headers={"Content-Type": "application/x-www-form-urlencoded"},
                          verify=WrapperProperties.SSL_SECURE.get_bool(props),
                          timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props))

        # Check HTTP Status Code is 2xx Success
        SamlUtils.validate_response(r)
        return r.text

    def _get_sign_in_page_url(self, props) -> str:
        """Build the ADFS IdP-initiated sign-on URL from the IdP endpoint, port and relying party id.

        :raises AwsWrapperError: if ``IDP_ENDPOINT`` or ``RELAYING_PARTY_ID`` is not set.
        """
        idp_endpoint = WrapperProperties.IDP_ENDPOINT.get(props)
        idp_port = WrapperProperties.IDP_PORT.get_int(props)
        relaying_party_id = WrapperProperties.RELAYING_PARTY_ID.get(props)
        url = f"https://{idp_endpoint}:{idp_port}/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp={relaying_party_id}"
        if idp_endpoint is None or relaying_party_id is None:
            error_message = "SamlUtils.InvalidHttpsUrl"
            logger.debug(error_message, url)
            raise AwsWrapperError(Messages.get_formatted(error_message, url))

        return url

    def _get_form_action_url(self, props: Properties, action: str) -> str:
        """Resolve a relative form ``action`` path against the configured IdP endpoint and port.

        The port is read with ``get_int``, as in ``_get_sign_in_page_url``, so both URLs
        address the same ``host:port``.

        :raises AwsWrapperError: if ``IDP_ENDPOINT`` is not set.
        """
        idp_endpoint = WrapperProperties.IDP_ENDPOINT.get(props)
        idp_port = WrapperProperties.IDP_PORT.get_int(props)
        url = f"https://{idp_endpoint}:{idp_port}{action}"
        if idp_endpoint is None:
            error_message = "SamlUtils.InvalidHttpsUrl"
            logger.debug(error_message, url)
            raise AwsWrapperError(
                Messages.get_formatted(error_message, url))

        return url

    def _get_input_tags_from_html(self, body: str) -> List[str]:
        """Return the attribute text of each ``<input .../>`` tag, one per distinct non-empty ``name``.

        Tags without a ``name`` are dropped. When several tags share a name, only the
        first is kept, so the first occurrence of each form field wins.
        """
        seen_names: Set[str] = set()
        distinct_input_tags: List[str] = []
        # findall with a single capture group yields the text between "<input" and "/>".
        for input_tag in findall(self._INPUT_TAG_PATTERN, body, DOTALL):
            tag_name: str = self._get_value_by_key(input_tag, "name")
            if tag_name != "" and tag_name not in seen_names:
                seen_names.add(tag_name)
                distinct_input_tags.append(input_tag)

        return distinct_input_tags

    def _get_value_by_key(self, input: str, key: str) -> str:
        """Return the HTML-unescaped double-quoted value of attribute ``key`` in ``input``, or ``""``.

        ``key`` is inserted into the regex unescaped; callers pass plain attribute names.
        """
        key_value_pattern = r"(" + key + ")\\s*=\\s*\"(.*?)\""
        match = search(key_value_pattern, input)
        if match:
            return unescape(match.group(2))
        return ""

    def _get_parameters_from_html_body(self, body: str, props: Properties) -> Dict[str, str]:
        """Build the sign-in form parameters from the page's input tags.

        Rules, matched on the lower-cased input name:
        - Fields whose name contains "username" or "password" receive ``IDP_USERNAME`` /
          ``IDP_PASSWORD``. They are omitted when the property is unset.
        - "authmethod" fields keep their value only when it is non-empty.
        - Any other named field keeps its value from the page.
        """
        parameters: Dict[str, str] = {}
        for input_tag in self._get_input_tags_from_html(body):
            name: str = self._get_value_by_key(input_tag, "name")
            name_lower: str = name.lower()
            value: str = self._get_value_by_key(input_tag, "value")

            if "username" in name_lower:
                idp_user = WrapperProperties.IDP_USERNAME.get(props)
                if idp_user is not None:
                    parameters[name] = idp_user
            elif "authmethod" in name_lower:
                if value != "":
                    parameters[name] = value
            elif "password" in name_lower:
                idp_password = WrapperProperties.IDP_PASSWORD.get(props)
                if idp_password is not None:
                    parameters[name] = idp_password
            elif name != "":
                parameters[name] = value

        return parameters

    def _get_form_action_from_html_body(self, body: str) -> str:
        """Return the HTML-unescaped ``action`` attribute of the first ``<form>`` in ``body``, or ``""``."""
        match = search(self._FORM_ACTION_PATTERN, body)
        if match:
            return unescape(match.group(1))
        return ""