diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java new file mode 100644 index 000000000..8d48c2dff --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DataPlaneTokenSource.java @@ -0,0 +1,115 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.http.HttpClient; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Manages and provides Databricks data plane tokens. This class is responsible for acquiring and + * caching OAuth tokens that are specific to a particular Databricks data plane service endpoint and + * a set of authorization details. It utilizes a {@link DatabricksOAuthTokenSource} for obtaining + * control plane tokens, which may then be exchanged or used to authorize requests for data plane + * tokens. Cached {@link EndpointTokenSource} instances are used to efficiently reuse tokens for + * repeated requests to the same endpoint with the same authorization context. + */ +public class DataPlaneTokenSource { + private final HttpClient httpClient; + private final TokenSource cpTokenSource; + private final String host; + private final ConcurrentHashMap sourcesCache; + /** + * Caching key for {@link EndpointTokenSource}, based on endpoint and authorization details. This + * is a value object that uniquely identifies a token source configuration. + */ + private static final class TokenSourceKey { + /** The target service endpoint URL. */ + private final String endpoint; + + /** Specific authorization details for the endpoint. */ + private final String authDetails; + + /** + * Constructs a TokenSourceKey. + * + * @param endpoint The target service endpoint URL. + * @param authDetails Specific authorization details. + */ + public TokenSourceKey(String endpoint, String authDetails) { + this.endpoint = endpoint; + this.authDetails = authDetails; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TokenSourceKey that = (TokenSourceKey) o; + return Objects.equals(endpoint, that.endpoint) + && Objects.equals(authDetails, that.authDetails); + } + + @Override + public int hashCode() { + return Objects.hash(endpoint, authDetails); + } + } + + /** + * Constructs a DataPlaneTokenSource. + * + * @param httpClient The {@link HttpClient} for token requests. + * @param cpTokenSource The {@link TokenSource} for control plane tokens. + * @param host The host for the token exchange request. + * @throws NullPointerException if any parameter is null. + * @throws IllegalArgumentException if the host is empty. + */ + public DataPlaneTokenSource(HttpClient httpClient, TokenSource cpTokenSource, String host) { + this.httpClient = Objects.requireNonNull(httpClient, "HTTP client cannot be null"); + this.cpTokenSource = + Objects.requireNonNull(cpTokenSource, "Control plane token source cannot be null"); + this.host = Objects.requireNonNull(host, "Host cannot be null"); + + if (host.isEmpty()) { + throw new IllegalArgumentException("Host cannot be empty"); + } + this.sourcesCache = new ConcurrentHashMap<>(); + } + + /** + * Retrieves a token for the specified endpoint and authorization details. It uses a cached {@link + * EndpointTokenSource} if available, otherwise creates and caches a new one. + * + * @param endpoint The target data plane service endpoint. + * @param authDetails Authorization details for the endpoint. + * @return The dataplane {@link Token}. + * @throws NullPointerException if either parameter is null. + * @throws IllegalArgumentException if either parameter is empty. + * @throws DatabricksException if the token request fails. + */ + public Token getToken(String endpoint, String authDetails) { + Objects.requireNonNull(endpoint, "Data plane endpoint URL cannot be null"); + Objects.requireNonNull(authDetails, "Authorization details cannot be null"); + + if (endpoint.isEmpty()) { + throw new IllegalArgumentException("Data plane endpoint URL cannot be empty"); + } + if (authDetails.isEmpty()) { + throw new IllegalArgumentException("Authorization details cannot be empty"); + } + + TokenSourceKey key = new TokenSourceKey(endpoint, authDetails); + + EndpointTokenSource specificSource = + sourcesCache.computeIfAbsent( + key, + k -> + new EndpointTokenSource( + this.cpTokenSource, k.authDetails, this.httpClient, this.host)); + + return specificSource.getToken(); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java new file mode 100644 index 000000000..3ca75c441 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/EndpointTokenSource.java @@ -0,0 +1,97 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import java.time.LocalDateTime; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Represents a token source that exchanges a control plane token for an endpoint-specific dataplane + * token. It utilizes an underlying {@link TokenSource} to obtain the initial control plane token. + */ +public class EndpointTokenSource extends RefreshableTokenSource { + private static final Logger LOG = LoggerFactory.getLogger(EndpointTokenSource.class); + private static final String JWT_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:jwt-bearer"; + private static final String GRANT_TYPE_PARAM = "grant_type"; + private static final String AUTHORIZATION_DETAILS_PARAM = "authorization_details"; + private static final String ASSERTION_PARAM = "assertion"; + private static final String TOKEN_ENDPOINT = "/oidc/v1/token"; + + private final TokenSource cpTokenSource; + private final String authDetails; + private final HttpClient httpClient; + private final String host; + + /** + * Constructs a new EndpointTokenSource. + * + * @param cpTokenSource The {@link TokenSource} used to obtain the control plane token. + * @param authDetails The authorization details required for the token exchange. + * @param httpClient The {@link HttpClient} used to make the token exchange request. + * @param host The host for the token exchange request. + * @throws IllegalArgumentException if authDetails is empty or host is empty. + * @throws NullPointerException if any of the parameters are null. + */ + public EndpointTokenSource( + TokenSource cpTokenSource, String authDetails, HttpClient httpClient, String host) { + this.cpTokenSource = + Objects.requireNonNull(cpTokenSource, "Control plane token source cannot be null"); + this.authDetails = Objects.requireNonNull(authDetails, "Authorization details cannot be null"); + this.httpClient = Objects.requireNonNull(httpClient, "HTTP client cannot be null"); + this.host = Objects.requireNonNull(host, "Host cannot be null"); + + if (authDetails.isEmpty()) { + throw new IllegalArgumentException("Authorization details cannot be empty"); + } + if (host.isEmpty()) { + throw new IllegalArgumentException("Host cannot be empty"); + } + } + + /** + * Fetches an endpoint-specific dataplane token by exchanging a control plane token. + * + *

This method first obtains a control plane token from the configured {@code cpTokenSource}. + * It then uses this token as an assertion along with the provided {@code authDetails} to request + * a new, more scoped dataplane token from the Databricks OAuth token endpoint ({@value + * #TOKEN_ENDPOINT}). + * + * @return A new {@link Token} containing the exchanged dataplane access token, its type, any + * accompanying refresh token, and its expiry time. + * @throws DatabricksException if the token exchange with the OAuth endpoint fails. + * @throws IllegalArgumentException if the token endpoint url is empty. + * @throws NullPointerException if any of the parameters are null. + */ + @Override + protected Token refresh() { + Token cpToken = cpTokenSource.getToken(); + Map params = new HashMap<>(); + params.put(GRANT_TYPE_PARAM, JWT_GRANT_TYPE); + params.put(AUTHORIZATION_DETAILS_PARAM, authDetails); + params.put(ASSERTION_PARAM, cpToken.getAccessToken()); + + OAuthResponse oauthResponse; + try { + oauthResponse = + TokenEndpointClient.requestToken(this.httpClient, this.host + TOKEN_ENDPOINT, params); + } catch (DatabricksException | IllegalArgumentException | NullPointerException e) { + LOG.error( + "Failed to exchange control plane token for dataplane token at endpoint {}: {}", + TOKEN_ENDPOINT, + e.getMessage(), + e); + throw e; + } + + LocalDateTime expiry = LocalDateTime.now().plusSeconds(oauthResponse.getExpiresIn()); + return new Token( + oauthResponse.getAccessToken(), + oauthResponse.getTokenType(), + oauthResponse.getRefreshToken(), + expiry); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java new file mode 100644 index 000000000..69883dd24 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenEndpointClient.java @@ -0,0 +1,91 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.Map; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Client for interacting with an OAuth token endpoint. + * + *

This class provides a method to request an OAuth token from a specified token endpoint URL + * using the provided HTTP client and request parameters. It handles the HTTP request and parses the + * JSON response into an {@link OAuthResponse} object. + */ +public final class TokenEndpointClient { + private static final Logger LOG = LoggerFactory.getLogger(TokenEndpointClient.class); + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private TokenEndpointClient() {} + + /** + * Requests an OAuth token from the specified token endpoint. + * + * @param httpClient The {@link HttpClient} to use for making the request. + * @param tokenEndpointUrl The URL of the token endpoint. + * @param params A map of parameters to include in the token request. + * @return An {@link OAuthResponse} containing the token information. + * @throws DatabricksException if an error occurs during the token request or response parsing. + * @throws IllegalArgumentException if the token endpoint URL is empty. + * @throws NullPointerException if any of the parameters are null. + */ + public static OAuthResponse requestToken( + HttpClient httpClient, String tokenEndpointUrl, Map params) + throws DatabricksException { + Objects.requireNonNull(httpClient, "HttpClient cannot be null"); + Objects.requireNonNull(params, "Request parameters map cannot be null"); + Objects.requireNonNull(tokenEndpointUrl, "Token endpoint URL cannot be null"); + + if (tokenEndpointUrl.isEmpty()) { + throw new IllegalArgumentException("Token endpoint URL cannot be empty"); + } + + Response rawResponse; + try { + LOG.debug("Requesting token from endpoint: {}", tokenEndpointUrl); + rawResponse = httpClient.execute(new FormRequest(tokenEndpointUrl, params)); + } catch (IOException e) { + LOG.error("Failed to request token from {}: {}", tokenEndpointUrl, e.getMessage(), e); + throw new DatabricksException( + String.format("Failed to request token from %s: %s", tokenEndpointUrl, e.getMessage()), + e); + } + + OAuthResponse response; + try { + response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); + } catch (IOException e) { + LOG.error( + "Failed to parse OAuth response from token endpoint {}: {}", + tokenEndpointUrl, + e.getMessage(), + e); + throw new DatabricksException( + String.format( + "Failed to parse OAuth response from token endpoint %s: %s", + tokenEndpointUrl, e.getMessage()), + e); + } + + if (response.getErrorCode() != null) { + String errorSummary = + response.getErrorSummary() != null ? response.getErrorSummary() : "No summary provided."; + LOG.error( + "Token request to {} failed with error: {} - {}", + tokenEndpointUrl, + response.getErrorCode(), + errorSummary); + throw new DatabricksException( + String.format( + "Token request failed with error: %s - %s", response.getErrorCode(), errorSummary)); + } + LOG.debug("Successfully obtained token response from {}", tokenEndpointUrl); + return response; + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java new file mode 100644 index 000000000..5887c4ee1 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DataPlaneTokenSourceTest.java @@ -0,0 +1,256 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.time.LocalDateTime; +import java.util.stream.Stream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.MockedConstruction; + +public class DataPlaneTokenSourceTest { + private static final String TEST_ENDPOINT_1 = "https://endpoint1.databricks.com/"; + private static final String TEST_ENDPOINT_2 = "https://endpoint2.databricks.com/"; + private static final String TEST_AUTH_DETAILS_1 = "{\"aud\":\"aud1\"}"; + private static final String TEST_AUTH_DETAILS_2 = "{\"aud\":\"aud2\"}"; + private static final String TEST_CP_TOKEN = "cp-access-token"; + private static final String TEST_TOKEN_TYPE = "Bearer"; + private static final String TEST_REFRESH_TOKEN = "refresh-token"; + private static final int TEST_EXPIRES_IN = 3600; + private static final String TEST_HOST = "https://test.databricks.com"; + + private static Stream provideDataPlaneTokenScenarios() throws Exception { + // Mock DatabricksOAuthTokenSource for control plane token + Token cpToken = + new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, LocalDateTime.now().plusSeconds(600)); + DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); + when(mockCpTokenSource.getToken()).thenReturn(cpToken); + + // Success JSON for endpoint1/auth1 + String successJson1 = + "{" + + "\"access_token\":\"dp-access-token1\"," + + "\"token_type\":\"Bearer\"," + + "\"refresh_token\":\"refresh-token\"," + + "\"expires_in\":3600" + + "}"; + HttpClient mockSuccessClient1 = mock(HttpClient.class); + when(mockSuccessClient1.execute(any())) + .thenReturn(new Response(successJson1, 200, "OK", new URL(TEST_ENDPOINT_1))); + + // Success JSON for endpoint2/auth2 + String successJson2 = + "{" + + "\"access_token\":\"dp-access-token2\"," + + "\"token_type\":\"Bearer\"," + + "\"refresh_token\":\"refresh-token\"," + + "\"expires_in\":3600" + + "}"; + HttpClient mockSuccessClient2 = mock(HttpClient.class); + when(mockSuccessClient2.execute(any())) + .thenReturn(new Response(successJson2, 200, "OK", new URL(TEST_ENDPOINT_2))); + + String errorJson = + "{" + "\"error\":\"invalid_request\"," + "\"error_description\":\"Bad request\"" + "}"; + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any())) + .thenReturn(new Response(errorJson, 400, "Bad Request", new URL(TEST_ENDPOINT_1))); + + // IOException scenario + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any())).thenThrow(new IOException("Network error")); + + // For null/empty endpoint or authDetails + return Stream.of( + Arguments.of( + "Success: endpoint1/auth1", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + TEST_HOST, + new Token( + "dp-access-token1", + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + null // No exception + ), + Arguments.of( + "Success: endpoint2/auth2 (different cache key)", + TEST_ENDPOINT_2, + TEST_AUTH_DETAILS_2, + mockSuccessClient2, + mockCpTokenSource, + TEST_HOST, + new Token( + "dp-access-token2", + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + LocalDateTime.now().plusSeconds(TEST_EXPIRES_IN)), + null), + Arguments.of( + "Error response from endpoint", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockErrorClient, + mockCpTokenSource, + TEST_HOST, + null, + com.databricks.sdk.core.DatabricksException.class), + Arguments.of( + "IOException from HttpClient", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockIOExceptionClient, + mockCpTokenSource, + TEST_HOST, + null, + com.databricks.sdk.core.DatabricksException.class), + Arguments.of( + "Null cpTokenSource", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + null, + TEST_HOST, + null, + NullPointerException.class), + Arguments.of( + "Null httpClient", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + null, + mockCpTokenSource, + TEST_HOST, + null, + NullPointerException.class), + Arguments.of( + "Null endpoint", + null, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + TEST_HOST, + null, + NullPointerException.class), + Arguments.of( + "Null authDetails", + TEST_ENDPOINT_1, + null, + mockSuccessClient1, + mockCpTokenSource, + TEST_HOST, + null, + NullPointerException.class), + Arguments.of( + "Null host", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + null, + null, + NullPointerException.class), + Arguments.of( + "Empty host", + TEST_ENDPOINT_1, + TEST_AUTH_DETAILS_1, + mockSuccessClient1, + mockCpTokenSource, + "", + null, + IllegalArgumentException.class)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideDataPlaneTokenScenarios") + void testDataPlaneTokenSource( + String testName, + String endpoint, + String authDetails, + HttpClient httpClient, + DatabricksOAuthTokenSource cpTokenSource, + String host, + Token expectedToken, + Class expectedException) { + if (expectedException != null) { + assertThrows( + expectedException, + () -> { + DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource, host); + source.getToken(endpoint, authDetails); + }); + } else { + DataPlaneTokenSource source = new DataPlaneTokenSource(httpClient, cpTokenSource, host); + Token token = source.getToken(endpoint, authDetails); + assertNotNull(token); + assertEquals(expectedToken.getAccessToken(), token.getAccessToken()); + assertEquals(expectedToken.getTokenType(), token.getTokenType()); + assertEquals(expectedToken.getRefreshToken(), token.getRefreshToken()); + assertTrue(token.isValid()); + } + } + + @Test + void testEndpointTokenSourceCaching() throws Exception { + Token cpToken = + new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, null, LocalDateTime.now().plusSeconds(3600)); + DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); + when(mockCpTokenSource.getToken()).thenReturn(cpToken); + + String successJson = + "{\"access_token\":\"dp-access-token\",\"token_type\":\"Bearer\",\"refresh_token\":\"refresh-token\",\"expires_in\":3600}"; + HttpClient mockHttpClient = mock(HttpClient.class); + when(mockHttpClient.execute(any())) + .thenReturn(new Response(successJson, 200, "OK", new URL(TEST_ENDPOINT_1))); + + try (MockedConstruction mockedConstruction = + mockConstruction(EndpointTokenSource.class)) { + DataPlaneTokenSource source = + new DataPlaneTokenSource(mockHttpClient, mockCpTokenSource, TEST_HOST); + + // First call - should create new EndpointTokenSource + source.getToken(TEST_ENDPOINT_1, TEST_AUTH_DETAILS_1); + assertEquals( + 1, + mockedConstruction.constructed().size(), + "First call should create one EndpointTokenSource"); + + // Second call with same endpoint and auth details - should reuse existing EndpointTokenSource + source.getToken(TEST_ENDPOINT_1, TEST_AUTH_DETAILS_1); + assertEquals( + 1, + mockedConstruction.constructed().size(), + "This call should reuse the existing EndpointTokenSource"); + + // Call with different endpoint - should create new EndpointTokenSource + source.getToken(TEST_ENDPOINT_2, TEST_AUTH_DETAILS_2); + assertEquals( + 2, + mockedConstruction.constructed().size(), + "Different endpoint should create new EndpointTokenSource"); + + // Call with different auth details - should create new EndpointTokenSource + source.getToken(TEST_ENDPOINT_1, TEST_AUTH_DETAILS_2); + assertEquals( + 3, + mockedConstruction.constructed().size(), + "Different auth details should create new EndpointTokenSource"); + + source.getToken(TEST_ENDPOINT_2, TEST_AUTH_DETAILS_2); + assertEquals( + 3, + mockedConstruction.constructed().size(), + "This call should reuse the existing EndpointTokenSource"); + } + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java new file mode 100644 index 000000000..a3af2254f --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/EndpointTokenSourceTest.java @@ -0,0 +1,220 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.time.LocalDateTime; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class EndpointTokenSourceTest { + private static final String TEST_AUTH_DETAILS = "{\"aud\":\"test-audience\"}"; + private static final String TEST_CP_TOKEN = "cp-access-token"; + private static final String TEST_DP_TOKEN = "dp-access-token"; + private static final String TEST_TOKEN_TYPE = "Bearer"; + private static final String TEST_REFRESH_TOKEN = "refresh-token"; + private static final int TEST_EXPIRES_IN = 3600; + private static final String TEST_HOST = "https://test.databricks.com"; + + private static Stream provideEndpointTokenScenarios() throws Exception { + String successJson = + "{" + + "\"access_token\":\"" + + TEST_DP_TOKEN + + "\"," + + "\"token_type\":\"" + + TEST_TOKEN_TYPE + + "\"," + + "\"expires_in\":" + + TEST_EXPIRES_IN + + "," + + "\"refresh_token\":\"" + + TEST_REFRESH_TOKEN + + "\"}"; + + String errorJson = + "{" + + "\"error\":\"invalid_client\"," + + "\"error_description\":\"Client authentication failed\"}"; + + String malformedJson = "{not valid json}"; + + // Mock DatabricksOAuthTokenSource for control plane token + Token cpToken = new Token(TEST_CP_TOKEN, TEST_TOKEN_TYPE, LocalDateTime.now().plusMinutes(10)); + DatabricksOAuthTokenSource mockCpTokenSource = mock(DatabricksOAuthTokenSource.class); + when(mockCpTokenSource.getToken()).thenReturn(cpToken); + + // Mock HttpClient for success + HttpClient mockSuccessClient = mock(HttpClient.class); + when(mockSuccessClient.execute(any())) + .thenReturn(new Response(successJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for error response + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any())) + .thenReturn( + new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for malformed JSON + HttpClient mockMalformedClient = mock(HttpClient.class); + when(mockMalformedClient.execute(any())) + .thenReturn( + new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for IOException + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any())).thenThrow(new IOException("Network error")); + + return Stream.of( + Arguments.of( + "Success response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockSuccessClient, + TEST_HOST, + null, // No exception expected + TEST_DP_TOKEN, + TEST_TOKEN_TYPE, + TEST_REFRESH_TOKEN, + TEST_EXPIRES_IN), + Arguments.of( + "OAuth error response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockErrorClient, + TEST_HOST, + DatabricksException.class, + null, + null, + null, + 0), + Arguments.of( + "Malformed JSON response", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockMalformedClient, + TEST_HOST, + DatabricksException.class, + null, + null, + null, + 0), + Arguments.of( + "IOException from HttpClient", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockIOExceptionClient, + TEST_HOST, + DatabricksException.class, + null, + null, + null, + 0), + Arguments.of( + "Null cpTokenSource", + null, + TEST_AUTH_DETAILS, + mockSuccessClient, + TEST_HOST, + NullPointerException.class, + null, + null, + null, + 0), + Arguments.of( + "Null authDetails", + mockCpTokenSource, + null, + mockSuccessClient, + TEST_HOST, + NullPointerException.class, + null, + null, + null, + 0), + Arguments.of( + "Empty authDetails", + mockCpTokenSource, + "", + mockSuccessClient, + TEST_HOST, + IllegalArgumentException.class, + null, + null, + null, + 0), + Arguments.of( + "Null httpClient", + mockCpTokenSource, + TEST_AUTH_DETAILS, + null, + TEST_HOST, + NullPointerException.class, + null, + null, + null, + 0), + Arguments.of( + "Null host", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockSuccessClient, + null, + NullPointerException.class, + null, + null, + null, + 0), + Arguments.of( + "Empty host", + mockCpTokenSource, + TEST_AUTH_DETAILS, + mockSuccessClient, + "", + IllegalArgumentException.class, + null, + null, + null, + 0)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideEndpointTokenScenarios") + void testEndpointTokenSource( + String testName, + DatabricksOAuthTokenSource cpTokenSource, + String authDetails, + HttpClient httpClient, + String host, + Class expectedException, + String expectedAccessToken, + String expectedTokenType, + String expectedRefreshToken, + int expectedExpiresIn) { + if (expectedException != null) { + assertThrows( + expectedException, + () -> { + EndpointTokenSource source = + new EndpointTokenSource(cpTokenSource, authDetails, httpClient, host); + source.getToken(); + }); + } else { + EndpointTokenSource source = + new EndpointTokenSource(cpTokenSource, authDetails, httpClient, host); + Token token = source.getToken(); + assertNotNull(token); + assertEquals(expectedAccessToken, token.getAccessToken()); + assertEquals(expectedTokenType, token.getTokenType()); + assertEquals(expectedRefreshToken, token.getRefreshToken()); + } + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java new file mode 100644 index 000000000..581c90143 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/TokenEndpointClientTest.java @@ -0,0 +1,171 @@ +package com.databricks.sdk.core.oauth; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +import com.databricks.sdk.core.DatabricksException; +import com.databricks.sdk.core.http.FormRequest; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.http.Response; +import java.io.IOException; +import java.net.URL; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +class TokenEndpointClientTest { + private static final String TOKEN_ENDPOINT_URL = "https://test.databricks.com/oauth/token"; + private static final Map PARAMS = new HashMap<>(); + + private static Stream provideTokenScenarios() throws Exception { + // Success response JSON + String successJson = + "{" + + "\"access_token\":\"test-access-token\"," + + "\"token_type\":\"Bearer\"," + + "\"expires_in\":3600," + + "\"refresh_token\":\"test-refresh-token\"}"; + // Error response JSON + String errorJson = + "{" + + "\"error\":\"invalid_client\"," + + "\"error_description\":\"Client authentication failed\"}"; + // Malformed JSON + String malformedJson = "{not valid json}"; + + // Mock HttpClient for success + HttpClient mockSuccessClient = mock(HttpClient.class); + when(mockSuccessClient.execute(any(FormRequest.class))) + .thenReturn(new Response(successJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for error response + HttpClient mockErrorClient = mock(HttpClient.class); + when(mockErrorClient.execute(any(FormRequest.class))) + .thenReturn( + new Response(errorJson, 400, "Bad Request", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for malformed JSON + HttpClient mockMalformedClient = mock(HttpClient.class); + when(mockMalformedClient.execute(any(FormRequest.class))) + .thenReturn( + new Response(malformedJson, 200, "OK", new URL("https://test.databricks.com/"))); + + // Mock HttpClient for IOException + HttpClient mockIOExceptionClient = mock(HttpClient.class); + when(mockIOExceptionClient.execute(any(FormRequest.class))) + .thenThrow(new IOException("Network error")); + + return Stream.of( + Arguments.of( + "Success response", + mockSuccessClient, + TOKEN_ENDPOINT_URL, + PARAMS, + null, // No exception expected + "test-access-token", + "Bearer", + 3600, + "test-refresh-token"), + Arguments.of( + "OAuth error response", + mockErrorClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null), + Arguments.of( + "Malformed JSON response", + mockMalformedClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null), + Arguments.of( + "IOException from HttpClient", + mockIOExceptionClient, + TOKEN_ENDPOINT_URL, + PARAMS, + DatabricksException.class, + null, + null, + 0, + null), + Arguments.of( + "Null HttpClient", + null, + TOKEN_ENDPOINT_URL, + PARAMS, + NullPointerException.class, + null, + null, + 0, + null), + Arguments.of( + "Null tokenEndpointUrl", + mockSuccessClient, + null, + PARAMS, + NullPointerException.class, + null, + null, + 0, + null), + Arguments.of( + "Empty tokenEndpointUrl", + mockSuccessClient, + "", + PARAMS, + IllegalArgumentException.class, + null, + null, + 0, + null), + Arguments.of( + "Null params", + mockSuccessClient, + TOKEN_ENDPOINT_URL, + null, + NullPointerException.class, + null, + null, + 0, + null)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideTokenScenarios") + void testRequestToken( + String testName, + HttpClient httpClient, + String tokenEndpointUrl, + Map params, + Class expectedException, + String expectedAccessToken, + String expectedTokenType, + int expectedExpiresIn, + String expectedRefreshToken) { + if (expectedException != null) { + assertThrows( + expectedException, + () -> TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params)); + } else { + OAuthResponse response = + TokenEndpointClient.requestToken(httpClient, tokenEndpointUrl, params); + assertNotNull(response); + assertEquals(expectedAccessToken, response.getAccessToken()); + assertEquals(expectedTokenType, response.getTokenType()); + assertEquals(expectedExpiresIn, response.getExpiresIn()); + assertEquals(expectedRefreshToken, response.getRefreshToken()); + } + } +}