|
1 | 1 | import base64 |
2 | 2 | import contextlib |
3 | 3 | import json |
| 4 | +import logging |
4 | 5 | import re |
5 | 6 | import uuid |
6 | 7 | import warnings |
| 8 | +from typing import Any |
7 | 9 | from urllib.parse import quote |
8 | 10 |
|
9 | 11 | import google.oauth2.credentials |
|
14 | 16 | from google.api_core.client_info import ClientInfo |
15 | 17 | from google.cloud import bigquery |
16 | 18 | from packaging.version import parse as parse_version |
| 19 | +from pydantic import BaseModel, ValidationError |
17 | 20 | from sqlalchemy.engine import URL, create_engine, make_url |
18 | 21 | from sqlalchemy.exc import ResourceClosedError |
19 | 22 |
|
|
33 | 36 | from deepnote_toolkit.sql.sql_utils import is_single_select_query |
34 | 37 | from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url |
35 | 38 |
|
| 39 | +logger = logging.getLogger(__name__) |
| 40 | + |
| 41 | + |
class IntegrationFederatedAuthParams(BaseModel):
    """Schema of the ``federatedAuthParams`` entry of a connection dict.

    Field names are camelCase because they mirror the keys of the incoming
    payload that is validated against this model.
    """

    # Identifier of the integration to request a federated token for.
    integrationId: str
    # Token forwarded in the ``UserPodAuthContextToken`` request header.
    authContextToken: str
| 45 | + |
| 46 | + |
class FederatedAuthResponseData(BaseModel):
    """Schema of the JSON body returned by the federated-auth-token endpoint.

    Field names are camelCase because they mirror the keys of the response
    payload that is validated against this model.
    """

    # Integration kind; compared against "trino", "big-query" and "snowflake".
    integrationType: str
    # Access token to inject into the connection parameters.
    accessToken: str
| 50 | + |
36 | 51 |
|
37 | 52 | def compile_sql_query( |
38 | 53 | skip_jinja_template_render, |
@@ -242,11 +257,97 @@ def _generate_temporary_credentials(integration_id): |
242 | 257 |
|
243 | 258 | response = requests.post(url, timeout=10, headers=headers) |
244 | 259 |
|
| 260 | + response.raise_for_status() |
| 261 | + |
245 | 262 | data = response.json() |
246 | 263 |
|
247 | 264 | return quote(data["username"]), quote(data["password"]) |
248 | 265 |
|
249 | 266 |
|
def _get_federated_auth_credentials(
    integration_id: str, user_pod_auth_context_token: str
) -> FederatedAuthResponseData:
    """Request a federated auth token for an integration from the userpod API.

    Args:
        integration_id: Integration whose token is requested; interpolated
            into the endpoint path.
        user_pod_auth_context_token: Value sent in the
            ``UserPodAuthContextToken`` request header.

    Returns:
        The response body validated as ``FederatedAuthResponseData``.

    Raises:
        requests.HTTPError: If the API answers with an error status
            (via ``raise_for_status``).
        pydantic.ValidationError: If the response body does not match
            ``FederatedAuthResponseData``.
    """
    endpoint = get_absolute_userpod_api_url(
        f"integrations/federated-auth-token/{integration_id}"
    )

    # Project credentials come first (needed in detached mode); the user's
    # auth-context token is then added to the same header mapping.
    request_headers = get_project_auth_headers()
    request_headers["UserPodAuthContextToken"] = user_pod_auth_context_token

    reply = requests.post(endpoint, timeout=10, headers=request_headers)
    # Fail on HTTP errors before attempting to parse the body.
    reply.raise_for_status()

    return FederatedAuthResponseData.model_validate(reply.json())
| 287 | + |
| 288 | + |
| 289 | +def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None: |
| 290 | + """Apply IAM credentials to the connection URL in-place.""" |
| 291 | + |
| 292 | + if "iamParams" not in sql_alchemy_dict: |
| 293 | + return |
| 294 | + |
| 295 | + integration_id = sql_alchemy_dict["iamParams"]["integrationId"] |
| 296 | + |
| 297 | + temporary_username, temporary_password = _generate_temporary_credentials( |
| 298 | + integration_id |
| 299 | + ) |
| 300 | + |
| 301 | + sql_alchemy_dict["url"] = replace_user_pass_in_pg_url( |
| 302 | + sql_alchemy_dict["url"], temporary_username, temporary_password |
| 303 | + ) |
| 304 | + |
| 305 | + |
| 306 | +def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None: |
| 307 | + """Fetch and apply federated auth credentials to connection params in-place.""" |
| 308 | + |
| 309 | + if "federatedAuthParams" not in sql_alchemy_dict: |
| 310 | + return |
| 311 | + |
| 312 | + try: |
| 313 | + federated_auth_params = IntegrationFederatedAuthParams.model_validate( |
| 314 | + sql_alchemy_dict["federatedAuthParams"] |
| 315 | + ) |
| 316 | + except ValidationError: |
| 317 | + logger.exception("Invalid federated auth params, try updating toolkit version") |
| 318 | + return |
| 319 | + |
| 320 | + federated_auth = _get_federated_auth_credentials( |
| 321 | + federated_auth_params.integrationId, federated_auth_params.authContextToken |
| 322 | + ) |
| 323 | + |
| 324 | + if federated_auth.integrationType == "trino": |
| 325 | + try: |
| 326 | + sql_alchemy_dict["params"]["connect_args"]["http_headers"][ |
| 327 | + "Authorization" |
| 328 | + ] = f"Bearer {federated_auth.accessToken}" |
| 329 | + except KeyError: |
| 330 | + logger.exception( |
| 331 | + "Invalid federated auth params, try updating toolkit version" |
| 332 | + ) |
| 333 | + elif federated_auth.integrationType == "big-query": |
| 334 | + try: |
| 335 | + sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken |
| 336 | + except KeyError: |
| 337 | + logger.exception( |
| 338 | + "Invalid federated auth params, try updating toolkit version" |
| 339 | + ) |
| 340 | + elif federated_auth.integrationType == "snowflake": |
| 341 | + logger.warning( |
| 342 | + "Snowflake federated auth is not supported yet, using the original connection URL" |
| 343 | + ) |
| 344 | + else: |
| 345 | + logger.error( |
| 346 | + "Unsupported integration type: %s, try updating toolkit version", |
| 347 | + federated_auth.integrationType, |
| 348 | + ) |
| 349 | + |
| 350 | + |
250 | 351 | @contextlib.contextmanager |
251 | 352 | def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict): |
252 | 353 | server = None |
@@ -346,16 +447,9 @@ def _query_data_source( |
346 | 447 | ): |
347 | 448 | sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False) |
348 | 449 |
|
349 | | - if "iamParams" in sql_alchemy_dict: |
350 | | - integration_id = sql_alchemy_dict["iamParams"]["integrationId"] |
| 450 | + _handle_iam_params(sql_alchemy_dict) |
351 | 451 |
|
352 | | - temporaryUsername, temporaryPassword = _generate_temporary_credentials( |
353 | | - integration_id |
354 | | - ) |
355 | | - |
356 | | - sql_alchemy_dict["url"] = replace_user_pass_in_pg_url( |
357 | | - sql_alchemy_dict["url"], temporaryUsername, temporaryPassword |
358 | | - ) |
| 452 | + _handle_federated_auth_params(sql_alchemy_dict) |
359 | 453 |
|
360 | 454 | with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url: |
361 | 455 | if url is None: |
|
0 commit comments