Skip to content

Commit dfeedd6

Browse files
authored
feat: Add support to refresh federated auth access token (#46)
1 parent 9162302 commit dfeedd6

2 files changed

Lines changed: 417 additions & 9 deletions

File tree

deepnote_toolkit/sql/sql_execution.py

Lines changed: 103 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import base64
22
import contextlib
33
import json
4+
import logging
45
import re
56
import uuid
67
import warnings
8+
from typing import Any
79
from urllib.parse import quote
810

911
import google.oauth2.credentials
@@ -14,6 +16,7 @@
1416
from google.api_core.client_info import ClientInfo
1517
from google.cloud import bigquery
1618
from packaging.version import parse as parse_version
19+
from pydantic import BaseModel, ValidationError
1720
from sqlalchemy.engine import URL, create_engine, make_url
1821
from sqlalchemy.exc import ResourceClosedError
1922

@@ -33,6 +36,18 @@
3336
from deepnote_toolkit.sql.sql_utils import is_single_select_query
3437
from deepnote_toolkit.sql.url_utils import replace_user_pass_in_pg_url
3538

39+
logger = logging.getLogger(__name__)
40+
41+
42+
class IntegrationFederatedAuthParams(BaseModel):
43+
integrationId: str
44+
authContextToken: str
45+
46+
47+
class FederatedAuthResponseData(BaseModel):
48+
integrationType: str
49+
accessToken: str
50+
3651

3752
def compile_sql_query(
3853
skip_jinja_template_render,
@@ -242,11 +257,97 @@ def _generate_temporary_credentials(integration_id):
242257

243258
response = requests.post(url, timeout=10, headers=headers)
244259

260+
response.raise_for_status()
261+
245262
data = response.json()
246263

247264
return quote(data["username"]), quote(data["password"])
248265

249266

267+
def _get_federated_auth_credentials(
268+
integration_id: str, user_pod_auth_context_token: str
269+
) -> FederatedAuthResponseData:
270+
"""Get federated auth credentials for the given integration ID and user pod auth context token."""
271+
272+
url = get_absolute_userpod_api_url(
273+
f"integrations/federated-auth-token/{integration_id}"
274+
)
275+
276+
# Add project credentials in detached mode
277+
headers = get_project_auth_headers()
278+
headers["UserPodAuthContextToken"] = user_pod_auth_context_token
279+
280+
response = requests.post(url, timeout=10, headers=headers)
281+
282+
response.raise_for_status()
283+
284+
data = FederatedAuthResponseData.model_validate(response.json())
285+
286+
return data
287+
288+
289+
def _handle_iam_params(sql_alchemy_dict: dict[str, Any]) -> None:
290+
"""Apply IAM credentials to the connection URL in-place."""
291+
292+
if "iamParams" not in sql_alchemy_dict:
293+
return
294+
295+
integration_id = sql_alchemy_dict["iamParams"]["integrationId"]
296+
297+
temporary_username, temporary_password = _generate_temporary_credentials(
298+
integration_id
299+
)
300+
301+
sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
302+
sql_alchemy_dict["url"], temporary_username, temporary_password
303+
)
304+
305+
306+
def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None:
307+
"""Fetch and apply federated auth credentials to connection params in-place."""
308+
309+
if "federatedAuthParams" not in sql_alchemy_dict:
310+
return
311+
312+
try:
313+
federated_auth_params = IntegrationFederatedAuthParams.model_validate(
314+
sql_alchemy_dict["federatedAuthParams"]
315+
)
316+
except ValidationError:
317+
logger.exception("Invalid federated auth params, try updating toolkit version")
318+
return
319+
320+
federated_auth = _get_federated_auth_credentials(
321+
federated_auth_params.integrationId, federated_auth_params.authContextToken
322+
)
323+
324+
if federated_auth.integrationType == "trino":
325+
try:
326+
sql_alchemy_dict["params"]["connect_args"]["http_headers"][
327+
"Authorization"
328+
] = f"Bearer {federated_auth.accessToken}"
329+
except KeyError:
330+
logger.exception(
331+
"Invalid federated auth params, try updating toolkit version"
332+
)
333+
elif federated_auth.integrationType == "big-query":
334+
try:
335+
sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken
336+
except KeyError:
337+
logger.exception(
338+
"Invalid federated auth params, try updating toolkit version"
339+
)
340+
elif federated_auth.integrationType == "snowflake":
341+
logger.warning(
342+
"Snowflake federated auth is not supported yet, using the original connection URL"
343+
)
344+
else:
345+
logger.error(
346+
"Unsupported integration type: %s, try updating toolkit version",
347+
federated_auth.integrationType,
348+
)
349+
350+
250351
@contextlib.contextmanager
251352
def _create_sql_ssh_uri(ssh_enabled, sql_alchemy_dict):
252353
server = None
@@ -346,16 +447,9 @@ def _query_data_source(
346447
):
347448
sshEnabled = sql_alchemy_dict.get("ssh_options", {}).get("enabled", False)
348449

349-
if "iamParams" in sql_alchemy_dict:
350-
integration_id = sql_alchemy_dict["iamParams"]["integrationId"]
450+
_handle_iam_params(sql_alchemy_dict)
351451

352-
temporaryUsername, temporaryPassword = _generate_temporary_credentials(
353-
integration_id
354-
)
355-
356-
sql_alchemy_dict["url"] = replace_user_pass_in_pg_url(
357-
sql_alchemy_dict["url"], temporaryUsername, temporaryPassword
358-
)
452+
_handle_federated_auth_params(sql_alchemy_dict)
359453

360454
with _create_sql_ssh_uri(sshEnabled, sql_alchemy_dict) as url:
361455
if url is None:

0 commit comments

Comments
 (0)