diff --git a/app/auth.py b/app/auth.py index d27d605..ee251cd 100644 --- a/app/auth.py +++ b/app/auth.py @@ -86,7 +86,35 @@ async def websocket_authenticate(websocket: WebSocket) -> str | None: return None -async def exchange_token_for_provider( +async def exchange_token(user_token: str, url: str) -> str: + """ + Retrieve the exchanged token for accessing an external backend. This is done by exchanging the + user's token for a platform-specific token using the configured token provider. + + :param url: The URL of the backend for which to exchange the token. This URL should be + configured in the BACKEND_CONFIG environment variable. + :return: The bearer token as a string. + """ + + provider = settings.backend_auth_config[url].token_provider + token_prefix = settings.backend_auth_config[url].token_prefix + + if not provider or not token_prefix: + raise ValueError( + f"Backend '{url}' must define 'token_provider' and 'token_prefix'" + ) + + platform_token = await _exchange_token_for_provider( + initial_token=user_token, provider=provider + ) + return ( + f"{token_prefix}/{platform_token['access_token']}" + if token_prefix + else platform_token["access_token"] + ) + + +async def _exchange_token_for_provider( initial_token: str, provider: str ) -> Dict[str, Any]: """ diff --git a/app/config/openeo/settings.py b/app/config/schemas.py similarity index 67% rename from app/config/openeo/settings.py rename to app/config/schemas.py index aedf49b..d19bd11 100644 --- a/app/config/openeo/settings.py +++ b/app/config/schemas.py @@ -3,13 +3,13 @@ from pydantic import BaseModel -class OpenEOAuthMethod(str, Enum): +class AuthMethod(str, Enum): CLIENT_CREDENTIALS = "CLIENT_CREDENTIALS" USER_CREDENTIALS = "USER_CREDENTIALS" -class OpenEOBackendConfig(BaseModel): - auth_method: OpenEOAuthMethod = OpenEOAuthMethod.USER_CREDENTIALS +class BackendAuthConfig(BaseModel): + auth_method: AuthMethod = AuthMethod.USER_CREDENTIALS client_credentials: Optional[str] = None token_provider: Optional[str] = None token_prefix: Optional[str] = None diff --git a/app/config/settings.py b/app/config/settings.py index 3c02ab6..0458aba 100644 --- a/app/config/settings.py +++ b/app/config/settings.py @@ -4,7 +4,7 @@ from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict -from app.config.openeo.settings import OpenEOAuthMethod, OpenEOBackendConfig +from app.config.schemas import AuthMethod, BackendAuthConfig class Settings(BaseSettings): @@ -40,29 +40,28 @@ class Settings(BaseSettings): default="", json_schema_extra={"env": "KEYCLOAK_CLIENT_SECRET"} ) - # openEO Settings - openeo_backends: str | None = Field( - default="", json_schema_extra={"env": "OPENEO_BACKENDS"} + # Backend auth configuration + backends: str | None = Field( + default="", json_schema_extra={"env": "BACKENDS"} ) + backend_auth_config: Dict[str, BackendAuthConfig] = Field(default_factory=dict) - openeo_backend_config: Dict[str, OpenEOBackendConfig] = Field(default_factory=dict) - - def load_openeo_backends_from_env(self): + def load_backends_auth_config(self): """ Populate self.backends from BACKENDS_JSON if provided, otherwise keep defaults. BACKENDS_JSON should be a JSON object keyed by hostname with BackendConfig-like values. """ required_fields = [] - if self.openeo_backends: + if self.backends: try: - raw = json.loads(self.openeo_backends) + raw = json.loads(self.backends) for host, cfg in raw.items(): - backend = OpenEOBackendConfig(**cfg) + backend = BackendAuthConfig(**cfg) - if backend.auth_method == OpenEOAuthMethod.CLIENT_CREDENTIALS: + if backend.auth_method == AuthMethod.CLIENT_CREDENTIALS: required_fields = ["client_credentials"] - elif backend.auth_method == OpenEOAuthMethod.USER_CREDENTIALS: + elif backend.auth_method == AuthMethod.USER_CREDENTIALS: required_fields = ["token_provider"] for field in required_fields: @@ -71,11 +70,11 @@ def load_openeo_backends_from_env(self): f"Backend '{host}' must define '{field}' when " f"OPENEO_AUTH_METHOD={backend.auth_method}" ) - self.openeo_backend_config[host] = OpenEOBackendConfig(**cfg) + self.backend_auth_config[host] = BackendAuthConfig(**cfg) except Exception: # Fall back or raise as appropriate raise settings = Settings() -settings.load_openeo_backends_from_env() +settings.load_backends_auth_config() diff --git a/app/platforms/implementations/openeo.py b/app/platforms/implementations/openeo.py index 0b7c3a6..8c4e1ba 100644 --- a/app/platforms/implementations/openeo.py +++ b/app/platforms/implementations/openeo.py @@ -9,8 +9,9 @@ from loguru import logger from stac_pydantic import Collection -from app.auth import exchange_token_for_provider -from app.config.settings import OpenEOAuthMethod, settings +from app.auth import exchange_token +from app.config.schemas import AuthMethod +from app.config.settings import settings from app.platforms.base import BaseProcessingPlatform from app.platforms.dispatcher import register_platform from app.schemas.enum import OutputFormatEnum, ProcessingStatusEnum, ProcessTypeEnum @@ -58,28 +59,6 @@ def _connection_expired(self, connection: openeo.Connection) -> bool: logger.warning("No JWT bearer token found in connection.") return True - async def _get_bearer_token(self, user_token: str, url: str) -> str: - """ - Retrieve the bearer token for the OpenEO backend. This is done by exchanging the user's - token for a platform-specific token using the configured token provider. - - :param url: The URL of the OpenEO backend. - :return: The bearer token as a string. - """ - - provider = settings.openeo_backend_config[url].token_provider - token_prefix = settings.openeo_backend_config[url].token_prefix - - if not provider or not token_prefix: - raise ValueError( - f"Backend '{url}' must define 'token_provider' and 'token_prefix'" - ) - - platform_token = await exchange_token_for_provider( - initial_token=user_token, provider=provider - ) - return f"{token_prefix}/{platform_token['access_token']}" - async def _authenticate_user( self, user_token: str, url: str, connection: openeo.Connection ) -> openeo.Connection: @@ -88,19 +67,19 @@ async def _authenticate_user( This method can be used to set the user's token for the connection. """ - if url not in settings.openeo_backend_config: + if url not in settings.backend_auth_config: raise ValueError(f"No OpenEO backend configuration found for URL: {url}") if ( - settings.openeo_backend_config[url].auth_method - == OpenEOAuthMethod.USER_CREDENTIALS + settings.backend_auth_config[url].auth_method + == AuthMethod.USER_CREDENTIALS ): logger.debug("Using user credentials for OpenEO connection authentication") - bearer_token = await self._get_bearer_token(user_token, url) + bearer_token = await exchange_token(user_token=user_token, url=url) connection.authenticate_bearer_token(bearer_token=bearer_token) elif ( - settings.openeo_backend_config[url].auth_method - == OpenEOAuthMethod.CLIENT_CREDENTIALS + settings.backend_auth_config[url].auth_method + == AuthMethod.CLIENT_CREDENTIALS ): logger.debug( "Using client credentials for OpenEO connection authentication" @@ -115,7 +94,7 @@ async def _authenticate_user( else: raise ValueError( "Unsupported OpenEO authentication method: " - f"{settings.openeo_backend_config[url].auth_method}" + f"{settings.backend_auth_config[url].auth_method}" ) return connection @@ -145,7 +124,7 @@ def _get_client_credentials(self, url: str) -> tuple[str, str, str]: :param url: The URL of the OpenEO backend. :return: A tuple containing provider ID, client ID, and client secret. """ - credentials_str = settings.openeo_backend_config[url].client_credentials + credentials_str = settings.backend_auth_config[url].client_credentials if not credentials_str: raise ValueError( diff --git a/tests/platforms/test_openeo_platform.py b/tests/platforms/test_openeo_platform.py index 1447e37..512af98 100644 --- a/tests/platforms/test_openeo_platform.py +++ b/tests/platforms/test_openeo_platform.py @@ -7,8 +7,8 @@ import pytest import requests +from app.config.schemas import AuthMethod, BackendAuthConfig from app.config.settings import settings -from app.config.openeo.settings import OpenEOBackendConfig, OpenEOAuthMethod from app.platforms.implementations.openeo import ( OpenEOPlatform, ) @@ -39,14 +39,14 @@ def platform(): @pytest.fixture(autouse=True) def mock_env(monkeypatch): - settings.openeo_backend_config["https://openeo.dataspace.copernicus.eu"] = ( - OpenEOBackendConfig( + settings.backend_auth_config["https://openeo.dataspace.copernicus.eu"] = ( + BackendAuthConfig( client_credentials="cdse-provider123/cdse-client123/cdse-secret123", token_prefix="cdse-prefix", token_provider="cdse-provider", ) ) - settings.openeo_backend_config["https://openeo.vito.be"] = OpenEOBackendConfig( + settings.backend_auth_config["https://openeo.vito.be"] = BackendAuthConfig( client_credentials="vito-provider123/vito-client123/vito-secret123", token_prefix="vito-prefix", token_provider="vito-provider", @@ -251,29 +251,27 @@ def test_connection_expired_exception(mock_decode, platform): @pytest.mark.asyncio @patch( - "app.platforms.implementations.openeo.exchange_token_for_provider", + "app.platforms.implementations.openeo.exchange_token", new_callable=AsyncMock, ) async def test_authenticate_user_with_user_credentials(mock_exchange, platform): url = "https://openeo.vito.be" # enable user credentials path - settings.openeo_backend_config[url].auth_method = OpenEOAuthMethod.USER_CREDENTIALS + settings.backend_auth_config[url].auth_method = AuthMethod.USER_CREDENTIALS # set up a fake connection with the expected method conn = MagicMock() conn.authenticate_bearer_token = MagicMock() # prepare the exchange mock to return the exchanged token - mock_exchange.return_value = {"access_token": "exchanged-token"} + mock_exchange.return_value = "vito-prefix/exchanged-token" # choose a url that maps via BACKEND_PROVIDER_ID_MAP (hostname only) returned = await platform._authenticate_user("user-token", url, conn) # assertions - mock_exchange.assert_awaited_once_with( - initial_token="user-token", provider="vito-provider" - ) + mock_exchange.assert_awaited_once_with(user_token="user-token", url=url) conn.authenticate_bearer_token.assert_called_once_with( bearer_token="vito-prefix/exchanged-token" ) @@ -282,7 +280,7 @@ async def test_authenticate_user_with_user_credentials(mock_exchange, platform): @pytest.mark.asyncio @patch( - "app.platforms.implementations.openeo.exchange_token_for_provider", + "app.platforms.implementations.openeo.exchange_token", new_callable=AsyncMock, ) async def test_authenticate_user_with_client_credentials( @@ -290,9 +288,7 @@ async def test_authenticate_user_with_client_credentials( ): url = "https://openeo.vito.be" # disable user credentials path -> use client credentials - settings.openeo_backend_config[url].auth_method = ( - OpenEOAuthMethod.CLIENT_CREDENTIALS - ) + settings.backend_auth_config[url].auth_method = AuthMethod.CLIENT_CREDENTIALS # prepare fake connection and spy method conn = MagicMock() @@ -314,7 +310,7 @@ async def test_authenticate_user_with_client_credentials( @pytest.mark.asyncio @patch( - "app.platforms.implementations.openeo.exchange_token_for_provider", + "app.platforms.implementations.openeo.exchange_token", new_callable=AsyncMock, ) async def test_authenticate_user_config_missing_url( @@ -337,7 +333,7 @@ async def test_authenticate_user_config_missing_url( @pytest.mark.asyncio @patch( - "app.platforms.implementations.openeo.exchange_token_for_provider", + "app.platforms.implementations.openeo.exchange_token", new_callable=AsyncMock, ) async def test_authenticate_user_config_unsupported_method( @@ -345,16 +341,14 @@ async def test_authenticate_user_config_unsupported_method( ): url = "https://openeo.vito.be" # disable user credentials path -> use client credentials - settings.openeo_backend_config[url].auth_method = "FOOBAR" + settings.backend_auth_config[url].auth_method = "FOOBAR" # prepare fake connection and spy method conn = MagicMock() conn.authenticate_oidc_client_credentials = MagicMock() # ensure the exchange mock exists but is not awaited - with pytest.raises( - ValueError, match="Unsupported OpenEO authentication method" - ): + with pytest.raises(ValueError, match="Unsupported OpenEO authentication method"): await platform._authenticate_user("user-token", url, conn) mock_exchange.assert_not_awaited() @@ -362,7 +356,7 @@ async def test_authenticate_user_config_unsupported_method( @pytest.mark.asyncio @patch( - "app.platforms.implementations.openeo.exchange_token_for_provider", + "app.platforms.implementations.openeo.exchange_token", new_callable=AsyncMock, ) async def test_authenticate_user_config_missing_credentials( @@ -370,10 +364,8 @@ async def test_authenticate_user_config_missing_credentials( ): url = "https://openeo.vito.be" # disable user credentials path -> use client credentials - settings.openeo_backend_config[url].auth_method = ( - OpenEOAuthMethod.CLIENT_CREDENTIALS - ) - settings.openeo_backend_config[url].client_credentials = None + settings.backend_auth_config[url].auth_method = AuthMethod.CLIENT_CREDENTIALS + settings.backend_auth_config[url].client_credentials = None # prepare fake connection and spy method conn = MagicMock() @@ -390,7 +382,7 @@ async def test_authenticate_user_config_missing_credentials( @pytest.mark.asyncio @patch( - "app.platforms.implementations.openeo.exchange_token_for_provider", + "app.platforms.implementations.openeo.exchange_token", new_callable=AsyncMock, ) async def test_authenticate_user_config_format_issue_credentials( @@ -398,10 +390,8 @@ async def test_authenticate_user_config_format_issue_credentials( ): url = "https://openeo.vito.be" # disable user credentials path -> use client credentials - settings.openeo_backend_config[url].auth_method = ( - OpenEOAuthMethod.CLIENT_CREDENTIALS - ) - settings.openeo_backend_config[url].client_credentials = "foobar" + settings.backend_auth_config[url].auth_method = AuthMethod.CLIENT_CREDENTIALS + settings.backend_auth_config[url].client_credentials = "foobar" # prepare fake connection and spy method conn = MagicMock() @@ -414,52 +404,6 @@ async def test_authenticate_user_config_format_issue_credentials( mock_exchange.assert_not_awaited() -@pytest.mark.asyncio -@patch( - "app.platforms.implementations.openeo.exchange_token_for_provider", - new_callable=AsyncMock, -) -async def test_authenticate_user_config_missing_provider( - mock_exchange, monkeypatch, platform -): - url = "https://openeo.vito.be" - # disable user credentials path -> use client credentials - settings.openeo_backend_config[url].token_provider = None - - # prepare fake connection and spy method - conn = MagicMock() - conn.authenticate_oidc_client_credentials = MagicMock() - - # ensure the exchange mock exists but is not awaited - with pytest.raises(ValueError, match="must define"): - await platform._authenticate_user("user-token", url, conn) - - mock_exchange.assert_not_awaited() - - -@pytest.mark.asyncio -@patch( - "app.platforms.implementations.openeo.exchange_token_for_provider", - new_callable=AsyncMock, -) -async def test_authenticate_user_config_missing_prefix( - mock_exchange, monkeypatch, platform -): - url = "https://openeo.vito.be" - # disable user credentials path -> use client credentials - settings.openeo_backend_config[url].token_prefix = None - - # prepare fake connection and spy method - conn = MagicMock() - conn.authenticate_oidc_client_credentials = MagicMock() - - # ensure the exchange mock exists but is not awaited - with pytest.raises(ValueError, match="must define"): - await platform._authenticate_user("user-token", url, conn) - - mock_exchange.assert_not_awaited() - - @pytest.mark.asyncio @patch("app.platforms.implementations.openeo.openeo.connect") @patch.object(OpenEOPlatform, "_authenticate_user", new_callable=AsyncMock) diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..5fb9805 --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,276 @@ +import pytest +from unittest.mock import MagicMock, patch, AsyncMock +import httpx +from fastapi import status + +from app.auth import exchange_token, _exchange_token_for_provider +from app.config.settings import settings +from app.config.schemas import BackendAuthConfig, AuthMethod +from app.error import AuthException + + +# Tests for exchange_token function +@pytest.mark.asyncio +async def test_exchange_token_missing_provider(): + + url = "https://openeo.vito.be" + original_config = settings.backend_auth_config.get(url) + + try: + # Create a config without token_provider + settings.backend_auth_config[url] = BackendAuthConfig( + auth_method=AuthMethod.USER_CREDENTIALS, + token_provider=None, + token_prefix="Bearer", + ) + + with pytest.raises(ValueError, match="must define"): + await exchange_token("user-token", url) + finally: + # Restore original config + if original_config: + settings.backend_auth_config[url] = original_config + + +@pytest.mark.asyncio +async def test_exchange_token_missing_token_prefix(): + url = "https://openeo.vito.be" + original_config = settings.backend_auth_config.get(url) + + try: + # Create a config without token_prefix + settings.backend_auth_config[url] = BackendAuthConfig( + auth_method=AuthMethod.USER_CREDENTIALS, + token_provider="openeo", + token_prefix=None, + ) + + with pytest.raises(ValueError, match="must define"): + await exchange_token("user-token", url) + finally: + # Restore original config + if original_config: + settings.backend_auth_config[url] = original_config + + +@pytest.mark.asyncio +@patch( + "app.auth._exchange_token_for_provider", + new_callable=AsyncMock, +) +async def test_exchange_token_success_with_prefix(mock_exchange): + url = "https://openeo.vito.be" + original_config = settings.backend_auth_config.get(url) + + try: + settings.backend_auth_config[url] = BackendAuthConfig( + auth_method=AuthMethod.USER_CREDENTIALS, + token_provider="openeo", + token_prefix="Bearer", + ) + + mock_exchange.return_value = {"access_token": "exchanged-token-123"} + + result = await exchange_token("user-token", url) + + assert result == "Bearer/exchanged-token-123" + mock_exchange.assert_called_once_with( + initial_token="user-token", + provider="openeo", + ) + finally: + if original_config: + settings.backend_auth_config[url] = original_config + + +# Tests for _exchange_token_for_provider function +@pytest.mark.asyncio +async def test_exchange_token_for_provider_missing_client_credentials(): + original_client_id = settings.keycloak_client_id + original_client_secret = settings.keycloak_client_secret + + try: + settings.keycloak_client_id = "" + settings.keycloak_client_secret = "" + + with pytest.raises(AuthException) as exc_info: + await _exchange_token_for_provider("token", "openeo") + + assert exc_info.value.http_status == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "not configured" in exc_info.value.message + finally: + settings.keycloak_client_id = original_client_id + settings.keycloak_client_secret = original_client_secret + + +@pytest.mark.asyncio +@patch( + "app.auth.httpx.AsyncClient", +) +async def test_exchange_token_for_provider_network_error(mock_client_class): + original_client_id = settings.keycloak_client_id + original_client_secret = settings.keycloak_client_secret + + try: + settings.keycloak_client_id = "test-client" + settings.keycloak_client_secret = "test-secret" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.post.side_effect = httpx.RequestError("Network error") + mock_client_class.return_value = mock_client + + with pytest.raises(AuthException) as exc_info: + await _exchange_token_for_provider("token", "openeo") + + assert exc_info.value.http_status == status.HTTP_502_BAD_GATEWAY + assert "Could not authenticate" in exc_info.value.message + finally: + settings.keycloak_client_id = original_client_id + settings.keycloak_client_secret = original_client_secret + + +@pytest.mark.asyncio +@patch( + "app.auth.httpx.AsyncClient", +) +async def test_exchange_token_for_provider_invalid_json_response(mock_client_class): + original_client_id = settings.keycloak_client_id + original_client_secret = settings.keycloak_client_secret + + try: + settings.keycloak_client_id = "test-client" + settings.keycloak_client_secret = "test-secret" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + + # Mock response with invalid JSON + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.side_effect = ValueError("Invalid JSON") + mock_client.post.return_value = mock_response + mock_client_class.return_value = mock_client + + with pytest.raises(AuthException) as exc_info: + await _exchange_token_for_provider("token", "openeo") + + assert exc_info.value.http_status == status.HTTP_502_BAD_GATEWAY + assert "Could not authenticate" in exc_info.value.message + finally: + settings.keycloak_client_id = original_client_id + settings.keycloak_client_secret = original_client_secret + + +@pytest.mark.asyncio +@patch( + "app.auth.httpx.AsyncClient", +) +async def test_exchange_token_for_provider_token_exchange_failed(mock_client_class): + original_client_id = settings.keycloak_client_id + original_client_secret = settings.keycloak_client_secret + + try: + settings.keycloak_client_id = "test-client" + settings.keycloak_client_secret = "test-secret" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + + # Mock response with 401 error + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.json.return_value = { + "error": "unauthorized", + "error_description": "Invalid credentials", + } + mock_response.text = "Unauthorized" + mock_client.post.return_value = mock_response + mock_client_class.return_value = mock_client + + with pytest.raises(AuthException) as exc_info: + await _exchange_token_for_provider("token", "openeo") + + assert exc_info.value.http_status == status.HTTP_401_UNAUTHORIZED + assert "Could not authenticate" in exc_info.value.message + finally: + settings.keycloak_client_id = original_client_id + settings.keycloak_client_secret = original_client_secret + + +@pytest.mark.asyncio +@patch( + "app.auth.httpx.AsyncClient", +) +async def test_exchange_token_for_provider_account_not_linked(mock_client_class): + original_client_id = settings.keycloak_client_id + original_client_secret = settings.keycloak_client_secret + + try: + settings.keycloak_client_id = "test-client" + settings.keycloak_client_secret = "test-secret" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + + # Mock response with not_linked error + mock_response = MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "error": "not_linked", + "error_description": "Account not linked", + } + mock_response.text = "Bad Request" + mock_client.post.return_value = mock_response + mock_client_class.return_value = mock_client + + with pytest.raises(AuthException) as exc_info: + await _exchange_token_for_provider("token", "openeo") + + assert exc_info.value.http_status == status.HTTP_401_UNAUTHORIZED + assert "link your account" in exc_info.value.message + assert "Account Dashboard" in exc_info.value.message + finally: + settings.keycloak_client_id = original_client_id + settings.keycloak_client_secret = original_client_secret + + +@pytest.mark.asyncio +@patch( + "app.auth.httpx.AsyncClient", +) +async def test_exchange_token_for_provider_success(mock_client_class): + original_client_id = settings.keycloak_client_id + original_client_secret = settings.keycloak_client_secret + + try: + settings.keycloak_client_id = "test-client" + settings.keycloak_client_secret = "test-secret" + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + + # Mock successful response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new-platform-token", + "expires_in": 3600, + "token_type": "Bearer", + } + mock_client.post.return_value = mock_response + mock_client_class.return_value = mock_client + + result = await _exchange_token_for_provider("user-token", "openeo") + + assert result["access_token"] == "new-platform-token" + assert result["expires_in"] == 3600 + assert result["token_type"] == "Bearer" + finally: + settings.keycloak_client_id = original_client_id + settings.keycloak_client_secret = original_client_secret