Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -9084,9 +9084,9 @@ lodash.uniq@^4.5.0:
integrity sha512-xfBaXQd9ryd9dlSDvnvI0lvxfLJlYAZzXomUYzLKtUeOQvOP5piqAWuGtrhWeqaXK9hhoM/iyJc5AV+XfsX3HQ==

lodash@^4.15.0, lodash@^4.17.10, lodash@^4.17.20, lodash@^4.17.21:
version "4.17.21"
resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c"
integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==
version "4.17.23"
resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.23.tgz#f113b0378386103be4f6893388c73d0bde7f2c5a"
integrity sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==

longest-streak@^3.0.0:
version "3.1.0"
Expand Down
222 changes: 111 additions & 111 deletions superset-frontend/package-lock.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions superset-frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@
"yargs": "^17.7.2"
},
"devDependencies": {
"@applitools/eyes-storybook": "^3.63.4",
"@applitools/eyes-storybook": "^3.63.8",
"@babel/cli": "^7.28.6",
"@babel/compat-data": "^7.28.4",
"@babel/core": "^7.28.6",
Expand Down Expand Up @@ -332,7 +332,7 @@
"jest-websocket-mock": "^2.5.0",
"jsdom": "^27.4.0",
"lerna": "^8.2.3",
"lightningcss": "^1.30.2",
"lightningcss": "^1.31.1",
"mini-css-extract-plugin": "^2.10.0",
"open-cli": "^8.0.0",
"oxlint": "^1.41.0",
Expand Down
2 changes: 1 addition & 1 deletion superset-frontend/packages/generator-superset/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
"dependencies": {
"chalk": "^5.6.2",
"lodash-es": "^4.17.22",
"lodash-es": "^4.17.23",
"yeoman-generator": "^7.5.1",
"yosay": "^3.0.0"
},
Expand Down
20 changes: 14 additions & 6 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,13 @@ def get_oauth2_token(
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
response = (
requests.post(uri, data=req_body, timeout=timeout)
if config["request_content_type"] == "data"
else requests.post(uri, json=req_body, timeout=timeout)
)
response.raise_for_status()
return response.json()

@classmethod
def get_oauth2_fresh_token(
Expand All @@ -738,9 +742,13 @@ def get_oauth2_fresh_token(
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}
if config["request_content_type"] == "data":
return requests.post(uri, data=req_body, timeout=timeout).json()
return requests.post(uri, json=req_body, timeout=timeout).json()
response = (
requests.post(uri, data=req_body, timeout=timeout)
if config["request_content_type"] == "data"
else requests.post(uri, json=req_body, timeout=timeout)
)
response.raise_for_status()
return response.json()

@classmethod
def get_allows_alias_in_select(
Expand Down
34 changes: 19 additions & 15 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,9 +896,7 @@ def get_all_table_names_in_schema(
)
}
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()

self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

@cache_util.memoized_func(
Expand Down Expand Up @@ -933,9 +931,7 @@ def get_all_view_names_in_schema(
)
}
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()

self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

@cache_util.memoized_func(
Expand Down Expand Up @@ -972,9 +968,7 @@ def get_all_materialized_view_names_in_schema(
)
}
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()

self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

return set()
Expand Down Expand Up @@ -1003,9 +997,7 @@ def get_all_schema_names(self, *, catalog: str | None = None) -> set[str]:
with self.get_inspector(catalog=catalog) as inspector:
return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()

self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

@cache_util.memoized_func(
Expand All @@ -1022,9 +1014,7 @@ def get_all_catalog_names(self) -> set[str]:
with self.get_inspector() as inspector:
return self.db_engine_spec.get_catalog_names(self, inspector)
except Exception as ex:
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()

self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

@property
Expand Down Expand Up @@ -1261,6 +1251,10 @@ def get_oauth2_config(self) -> OAuth2ClientConfig | None:
if oauth2_client_info := encrypted_extra.get("oauth2_client_info"):
schema = OAuth2ClientConfigSchema()
client_config = schema.load(oauth2_client_info)
if "request_content_type" not in oauth2_client_info:
client_config["request_content_type"] = (
self.db_engine_spec.oauth2_token_request_type
)
return cast(OAuth2ClientConfig, client_config)

return self.db_engine_spec.get_oauth2_config()
Expand All @@ -1275,6 +1269,16 @@ def start_oauth2_dance(self) -> None:
"""
return self.db_engine_spec.start_oauth2_dance(self)

def _handle_oauth2_error(self, ex: Exception) -> None:
"""
Handle exceptions that may require OAuth2 authentication.

If OAuth2 is enabled and the exception indicates that OAuth2 is needed,
starts the OAuth2 dance.
"""
if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
self.start_oauth2_dance()

def purge_oauth2_tokens(self) -> None:
"""
Delete all OAuth2 tokens associated with this database.
Expand Down
5 changes: 4 additions & 1 deletion superset/utils/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ class OAuth2ClientConfigSchema(Schema):
scope = fields.String(required=True)
redirect_uri = fields.String(
required=False,
load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
load_default=lambda: app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
),
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
Expand Down
81 changes: 81 additions & 0 deletions tests/unit_tests/models/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,34 @@ def test_get_oauth2_config(app_context: None) -> None:

assert database.get_oauth2_config() is None

database.encrypted_extra = json.dumps(oauth2_client_info)
assert database.get_oauth2_config() == {
"id": "my_client_id",
"secret": "my_client_secret",
"authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:USERADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
"request_content_type": "data", # Default value from BaseEngineSpec
}


def test_get_oauth2_config_token_request_type_from_db_engine_specs(
mocker: MockerFixture, app_context: None
) -> None:
"""
Test that DB Engine Spec overrides for ``oauth2_token_request_type`` are respected.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)
mocker.patch.object(
database.db_engine_spec,
"oauth2_token_request_type",
"json",
)

database.encrypted_extra = json.dumps(oauth2_client_info)
assert database.get_oauth2_config() == {
"id": "my_client_id",
Expand All @@ -672,6 +700,59 @@ def test_get_oauth2_config(app_context: None) -> None:
}


def test_get_oauth2_config_custom_token_request_type_extra(app_context: None) -> None:
"""
Test passing a custom ``token_request_type`` via ``encrypted_extra``
takes precedence.
"""
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)
custom_oauth2_client_info = {
"oauth2_client_info": {
**oauth2_client_info["oauth2_client_info"],
"request_content_type": "json",
}
}

database.encrypted_extra = json.dumps(custom_oauth2_client_info)
assert database.get_oauth2_config() == {
"id": "my_client_id",
"secret": "my_client_secret",
"authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
"token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
"scope": "refresh_token session:role:USERADMIN",
"redirect_uri": "http://example.com/api/v1/database/oauth2/",
"request_content_type": "json",
}


def test_get_oauth2_config_redirect_uri_from_config(
mocker: MockerFixture,
app_context: None,
) -> None:
"""
Test that ``DATABASE_OAUTH2_REDIRECT_URI`` config takes precedence over
url_for default.
"""
custom_redirect_uri = "https://custom.example.com/oauth/callback"
mocker.patch.dict(
"superset.utils.oauth2.app.config",
{"DATABASE_OAUTH2_REDIRECT_URI": custom_redirect_uri},
)
database = Database(
database_name="db",
sqlalchemy_uri="postgresql://user:password@host:5432/examples",
)
database.encrypted_extra = json.dumps(oauth2_client_info)

config = database.get_oauth2_config()

assert config is not None
assert config["redirect_uri"] == custom_redirect_uri


def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
Expand Down
Loading