diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 808c01fcf..96abd5fc1 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -3,6 +3,7 @@ ## Release v0.52.0 ### New Features and Improvements +* Added Direct-to-Dataplane API support, allowing users to query route optimized model serving endpoints ([#453](https://github.com/databricks/databricks-sdk-java/pull/453)). ### Bug Fixes diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/WorkspaceClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/WorkspaceClient.java index 559a5eabf..d4c066a69 100755 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/WorkspaceClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/WorkspaceClient.java @@ -141,6 +141,8 @@ import com.databricks.sdk.service.pipelines.PipelinesAPI; import com.databricks.sdk.service.pipelines.PipelinesService; import com.databricks.sdk.service.serving.ServingEndpointsAPI; +import com.databricks.sdk.service.serving.ServingEndpointsDataPlaneAPI; +import com.databricks.sdk.service.serving.ServingEndpointsDataPlaneService; import com.databricks.sdk.service.serving.ServingEndpointsService; import com.databricks.sdk.service.settings.CredentialsManagerAPI; import com.databricks.sdk.service.settings.CredentialsManagerService; @@ -297,6 +299,7 @@ public class WorkspaceClient { private SecretsExt secretsAPI; private ServicePrincipalsAPI servicePrincipalsAPI; private ServingEndpointsAPI servingEndpointsAPI; + private ServingEndpointsDataPlaneAPI servingEndpointsDataPlaneAPI; private SettingsAPI settingsAPI; private SharesAPI sharesAPI; private StatementExecutionAPI statementExecutionAPI; @@ -324,7 +327,6 @@ public WorkspaceClient() { public WorkspaceClient(DatabricksConfig config) { this.config = config; apiClient = new ApiClient(config); - accessControlAPI = new AccessControlAPI(apiClient); accountAccessControlProxyAPI = new AccountAccessControlProxyAPI(apiClient); alertsAPI = new AlertsAPI(apiClient); @@ -407,6 +409,8 @@ public WorkspaceClient(DatabricksConfig config) { secretsAPI = new SecretsExt(apiClient); servicePrincipalsAPI = new ServicePrincipalsAPI(apiClient); servingEndpointsAPI = new ServingEndpointsAPI(apiClient); + servingEndpointsDataPlaneAPI = + new ServingEndpointsDataPlaneAPI(apiClient, config, servingEndpointsAPI); settingsAPI = new SettingsAPI(apiClient); sharesAPI = new SharesAPI(apiClient); statementExecutionAPI = new StatementExecutionAPI(apiClient); @@ -1458,6 +1462,14 @@ public ServingEndpointsAPI servingEndpoints() { return servingEndpointsAPI; } + /** + * Serving endpoints DataPlane provides a set of operations to interact with data plane endpoints + * for Serving endpoints service. + */ + public ServingEndpointsDataPlaneAPI servingEndpointsDataPlane() { + return servingEndpointsDataPlaneAPI; + } + /** Workspace Settings API allows users to manage settings at the workspace level. */ public SettingsAPI settings() { return settingsAPI; @@ -2701,6 +2713,20 @@ public WorkspaceClient withServingEndpointsAPI(ServingEndpointsAPI servingEndpoi return this; } + /** Replace the default ServingEndpointsDataPlaneService with a custom implementation. */ + public WorkspaceClient withServingEndpointsDataPlaneImpl( + ServingEndpointsDataPlaneService servingEndpointsDataPlane) { + return this.withServingEndpointsDataPlaneAPI( + new ServingEndpointsDataPlaneAPI(servingEndpointsDataPlane)); + } + + /** Replace the default ServingEndpointsDataPlaneAPI with a custom implementation. 
*/ + public WorkspaceClient withServingEndpointsDataPlaneAPI( + ServingEndpointsDataPlaneAPI servingEndpointsDataPlane) { + this.servingEndpointsDataPlaneAPI = servingEndpointsDataPlane; + return this; + } + /** Replace the default SettingsService with a custom implementation. */ public WorkspaceClient withSettingsImpl(SettingsService settings) { return this.withSettingsAPI(new SettingsAPI(settings)); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ApiClient.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ApiClient.java index a45590b4d..2d4eeadc0 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ApiClient.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/ApiClient.java @@ -4,6 +4,7 @@ import com.databricks.sdk.core.error.PrivateLinkInfo; import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.RequestOptions; import com.databricks.sdk.core.http.Response; import com.databricks.sdk.core.retry.RequestBasedRetryStrategyPicker; import com.databricks.sdk.core.retry.RetryStrategy; @@ -50,7 +51,6 @@ public Builder withDatabricksConfig(DatabricksConfig config) { this.accountId = config.getAccountId(); this.retryStrategyPicker = new RequestBasedRetryStrategyPicker(config.getHost()); this.isDebugHeaders = config.isDebugHeaders(); - return this; } @@ -173,7 +173,7 @@ public Map getStringMap(Request req) { protected O withJavaType(Request request, JavaType javaType) { try { - Response response = getResponse(request); + Response response = executeInner(request, request.getUrl(), new RequestOptions()); return deserialize(response.getBody(), javaType); } catch (IOException e) { throw new DatabricksException("IO error: " + e.getMessage(), e); @@ -181,25 +181,34 @@ protected O withJavaType(Request request, JavaType javaType) { } /** - * Executes HTTP request with retries and converts it to proper POJO + * Executes HTTP request with retries and converts it to proper POJO. * * @param in Commons HTTP request * @param target Expected pojo type * @return POJO of requested type */ public T execute(Request in, Class target) throws IOException { - Response out = getResponse(in); + return execute(in, target, new RequestOptions()); + } + + /** + * Executes HTTP request with retries and converts it to proper POJO, using custom request + * options. + * + * @param in Commons HTTP request + * @param target Expected pojo type + * @param options Optional request options to customize request behavior + * @return POJO of requested type + */ + public T execute(Request in, Class target, RequestOptions options) throws IOException { + Response out = executeInner(in, in.getUrl(), options); if (target == Void.class) { return null; } return deserialize(out, target); } - private Response getResponse(Request in) { - return executeInner(in, in.getUrl()); - } - - private Response executeInner(Request in, String path) { + private Response executeInner(Request in, String path, RequestOptions options) { RetryStrategy retryStrategy = retryStrategyPicker.getRetryStrategy(in); int attemptNumber = 0; while (true) { @@ -224,6 +233,8 @@ private Response executeInner(Request in, String path) { } in.withHeader("User-Agent", userAgent); + options.applyOptions(in); + // Make the request, catching any exceptions, as we may want to retry. 
try { out = httpClient.execute(in); @@ -434,4 +445,8 @@ public String serialize(Object body) throws JsonProcessingException { } return mapper.writeValueAsString(body); } + + public HttpClient getHttpClient() { + return httpClient; + } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java index b08dfb8b4..6ed5b83d3 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/AzureCliCredentialsProvider.java @@ -1,5 +1,6 @@ package com.databricks.sdk.core; +import com.databricks.sdk.core.oauth.OAuthHeaderFactory; import com.databricks.sdk.core.oauth.Token; import com.databricks.sdk.core.utils.AzureUtils; import com.fasterxml.jackson.databind.ObjectMapper; @@ -68,7 +69,7 @@ private Optional getSubscription(DatabricksConfig config) { } @Override - public HeaderFactory configure(DatabricksConfig config) { + public OAuthHeaderFactory configure(DatabricksConfig config) { if (!config.isAzure()) { return null; } @@ -86,15 +87,17 @@ public HeaderFactory configure(DatabricksConfig config) { mgmtTokenSource = null; } CliTokenSource finalMgmtTokenSource = mgmtTokenSource; - return () -> { - Token token = tokenSource.getToken(); - Map headers = new HashMap<>(); - headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); - if (finalMgmtTokenSource != null) { - AzureUtils.addSpManagementToken(finalMgmtTokenSource, headers); - } - return AzureUtils.addWorkspaceResourceId(config, headers); - }; + return OAuthHeaderFactory.fromSuppliers( + tokenSource::getToken, + () -> { + Token token = tokenSource.getToken(); + Map headers = new HashMap<>(); + headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); + if (finalMgmtTokenSource != null) { + AzureUtils.addSpManagementToken(finalMgmtTokenSource, headers); + } + return AzureUtils.addWorkspaceResourceId(config, headers); + }); } catch (DatabricksException e) { String stderr = e.getMessage(); if (stderr.contains("not found")) { diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java index c20ac2891..655d0b599 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java @@ -1,6 +1,6 @@ package com.databricks.sdk.core; -import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.OAuthHeaderFactory; import com.databricks.sdk.core.utils.OSUtils; import java.util.*; import org.slf4j.Logger; @@ -36,7 +36,7 @@ private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { } @Override - public HeaderFactory configure(DatabricksConfig config) { + public OAuthHeaderFactory configure(DatabricksConfig config) { String host = config.getHost(); if (host == null) { return null; @@ -48,12 +48,7 @@ public HeaderFactory configure(DatabricksConfig config) { return null; } tokenSource.getToken(); // We need this for checking if databricks CLI is installed. 
- return () -> { - Token token = tokenSource.getToken(); - Map headers = new HashMap<>(); - headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); - return headers; - }; + return OAuthHeaderFactory.fromTokenSource(tokenSource); } catch (DatabricksException e) { String stderr = e.getMessage(); if (stderr.contains("not found")) { diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java index 98d75d4bc..de6548982 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksConfig.java @@ -4,7 +4,10 @@ import com.databricks.sdk.core.http.HttpClient; import com.databricks.sdk.core.http.Request; import com.databricks.sdk.core.http.Response; +import com.databricks.sdk.core.oauth.ErrorTokenSource; +import com.databricks.sdk.core.oauth.OAuthHeaderFactory; import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints; +import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.Cloud; import com.databricks.sdk.core.utils.Environment; import com.fasterxml.jackson.databind.ObjectMapper; @@ -209,6 +212,24 @@ public synchronized Map authenticate() throws DatabricksExceptio } } + public TokenSource getTokenSource() { + if (headerFactory == null) { + try { + ConfigLoader.fixHostIfNeeded(this); + headerFactory = credentialsProvider.configure(this); + } catch (Exception e) { + return new ErrorTokenSource("Failed to get token source: " + e.getMessage()); + } + setAuthType(credentialsProvider.authType()); + } + + if (headerFactory instanceof OAuthHeaderFactory) { + return (TokenSource) headerFactory; + } + return new ErrorTokenSource( + String.format("OAuth Token not supported for current auth type %s", authType)); + } + public CredentialsProvider getCredentialsProvider() { return this.credentialsProvider; } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java index 0e4723f36..f72aa435b 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java @@ -34,14 +34,6 @@ public NamedIDTokenSource(String name, IDTokenSource idTokenSource) { this.name = name; this.idTokenSource = idTokenSource; } - - public String getName() { - return name; - } - - public IDTokenSource getIdTokenSource() { - return idTokenSource; - } } public DefaultCredentialsProvider() {} @@ -143,14 +135,13 @@ private void addOIDCCredentialsProviders(DatabricksConfig config) { config.getClientId(), config.getHost(), endpoints, - namedIdTokenSource.getIdTokenSource(), + namedIdTokenSource.idTokenSource, config.getHttpClient()) .audience(config.getTokenAudience()) .accountId(config.isAccountClient() ? 
config.getAccountId() : null) .build(); - providers.add( - new TokenSourceCredentialsProvider(oauthTokenSource, namedIdTokenSource.getName())); + providers.add(new TokenSourceCredentialsProvider(oauthTokenSource, namedIdTokenSource.name)); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/http/RequestOptions.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/http/RequestOptions.java new file mode 100644 index 000000000..a9f6c487c --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/http/RequestOptions.java @@ -0,0 +1,71 @@ +package com.databricks.sdk.core.http; + +import java.util.function.Function; + +/** + * A builder class for configuring HTTP request transformations including authentication, URL, and + * user agent headers. + * + *
<p>
Experimental: this class is experimental and subject to change in backward incompatible ways. + */ +public class RequestOptions { + private Function authenticateFunc; + private Function urlFunc; + private Function userAgentFunc; + + /** + * Constructs a new RequestOptions instance with default identity functions. Initially, all + * transformations are set to pass through the request unchanged. + */ + public RequestOptions() { + // Default to identity functions + this.authenticateFunc = request -> request; + this.urlFunc = request -> request; + this.userAgentFunc = request -> request; + } + + /** + * Sets the authorization header for the request. + * + * @param authorization The authorization value to be set in the header + * @return This RequestOptions instance for method chaining + */ + public RequestOptions withAuthorization(String authorization) { + this.authenticateFunc = request -> request.withHeader("Authorization", authorization); + return this; + } + + /** + * Sets the URL for the request. + * + * @param url The URL to be set for the request + * @return This RequestOptions instance for method chaining + */ + public RequestOptions withUrl(String url) { + this.urlFunc = request -> request.withUrl(url); + return this; + } + + /** + * Sets the User-Agent header for the request. + * + * @param userAgent The user agent string to be set in the header + * @return This RequestOptions instance for method chaining + */ + public RequestOptions withUserAgent(String userAgent) { + this.userAgentFunc = request -> request.withHeader("User-Agent", userAgent); + return this; + } + + /** + * Applies all configured transformations to the given request. The transformations are applied in + * the following order: 1. Authentication 2. URL 3. User-Agent + * + * @param request The original request to be transformed + * @return A new Request instance with all transformations applied + */ + public Request applyOptions(Request request) { + // Apply all transformation functions in sequence + return userAgentFunc.apply(urlFunc.apply(authenticateFunc.apply(request))); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java index 316667114..b29b5aa0e 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProvider.java @@ -6,8 +6,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import java.util.Optional; /** @@ -25,7 +23,7 @@ public String authType() { } @Override - public HeaderFactory configure(DatabricksConfig config) { + public OAuthHeaderFactory configure(DatabricksConfig config) { if (!config.isAzure() || config.getAzureClientId() == null || config.getAzureTenantId() == null @@ -49,11 +47,7 @@ public HeaderFactory configure(DatabricksConfig config) { idToken.get(), "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); - return () -> { - Map headers = new HashMap<>(); - headers.put("Authorization", "Bearer " + tokenSource.getToken().getAccessToken()); - return headers; - }; + return OAuthHeaderFactory.fromTokenSource(tokenSource); } /** diff --git 
a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java index 432046777..c7c7bb672 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/AzureServicePrincipalCredentialsProvider.java @@ -19,7 +19,7 @@ public String authType() { } @Override - public HeaderFactory configure(DatabricksConfig config) { + public OAuthHeaderFactory configure(DatabricksConfig config) { if (!config.isAzure() || config.getAzureClientId() == null || config.getAzureClientSecret() == null @@ -32,13 +32,15 @@ public HeaderFactory configure(DatabricksConfig config) { RefreshableTokenSource cloud = tokenSourceFor(config, config.getAzureEnvironment().getServiceManagementEndpoint()); - return () -> { - Map headers = new HashMap<>(); - headers.put("Authorization", "Bearer " + inner.getToken().getAccessToken()); - AzureUtils.addWorkspaceResourceId(config, headers); - AzureUtils.addSpManagementToken(cloud, headers); - return headers; - }; + return OAuthHeaderFactory.fromSuppliers( + inner::getToken, + () -> { + Map headers = new HashMap<>(); + headers.put("Authorization", "Bearer " + inner.getToken().getAccessToken()); + AzureUtils.addWorkspaceResourceId(config, headers); + AzureUtils.addSpManagementToken(cloud, headers); + return headers; + }); } /** diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java index e642159c0..f16ae2aed 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSource.java @@ -1,15 +1,12 @@ package com.databricks.sdk.core.oauth; import com.databricks.sdk.core.DatabricksException; -import com.databricks.sdk.core.http.FormRequest; import com.databricks.sdk.core.http.HttpClient; -import com.databricks.sdk.core.http.Response; -import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Strings; -import java.io.IOException; import java.time.LocalDateTime; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -44,8 +41,6 @@ public class DatabricksOAuthTokenSource extends RefreshableTokenSource { private static final String SCOPE_PARAM = "scope"; private static final String CLIENT_ID_PARAM = "client_id"; - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); - private DatabricksOAuthTokenSource(Builder builder) { this.clientId = builder.clientId; this.host = builder.host; @@ -123,44 +118,29 @@ public DatabricksOAuthTokenSource build() { } } - /** - * Validates that a value is non-null for required fields. If the value is a string, it also - * checks that it is non-empty. - * - * @param value The value to validate. - * @param fieldName The name of the field being validated. - * @throws IllegalArgumentException when the value is null or an empty string. 
- */ - private static void validate(Object value, String fieldName) { - if (value == null) { - LOG.error("Required parameter '{}' is null", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be null", fieldName)); - } - if (value instanceof String && ((String) value).isEmpty()) { - LOG.error("Required parameter '{}' is empty", fieldName); - throw new IllegalArgumentException( - String.format("Required parameter '%s' cannot be empty", fieldName)); - } - } - /** * Retrieves an OAuth token by exchanging an ID token. Implements the OAuth token exchange flow to * obtain an access token. * * @return A Token containing the access token and related information. * @throws DatabricksException when the token exchange fails. - * @throws IllegalArgumentException when there is an error code in the response or when required - * parameters are missing. + * @throws IllegalArgumentException when the required string parameters are empty. + * @throws NullPointerException when any of the required parameters are null. */ @Override public Token refresh() { - // Validate all required parameters - validate(clientId, "ClientID"); - validate(host, "Host"); - validate(endpoints, "Endpoints"); - validate(idTokenSource, "IDTokenSource"); - validate(httpClient, "HttpClient"); + Objects.requireNonNull(clientId, "ClientID cannot be null"); + Objects.requireNonNull(host, "Host cannot be null"); + Objects.requireNonNull(endpoints, "Endpoints cannot be null"); + Objects.requireNonNull(idTokenSource, "IDTokenSource cannot be null"); + Objects.requireNonNull(httpClient, "HttpClient cannot be null"); + + if (clientId.isEmpty()) { + throw new IllegalArgumentException("ClientID cannot be empty"); + } + if (host.isEmpty()) { + throw new IllegalArgumentException("Host cannot be empty"); + } String effectiveAudience = determineAudience(); IDToken idToken = idTokenSource.getIDToken(effectiveAudience); @@ -172,47 +152,20 @@ public Token refresh() { params.put(SCOPE_PARAM, SCOPE); params.put(CLIENT_ID_PARAM, clientId); - Response rawResponse; - try { - rawResponse = httpClient.execute(new FormRequest(endpoints.getTokenEndpoint(), params)); - } catch (IOException e) { - LOG.error( - "Failed to exchange ID token for access token at {}: {}", - endpoints.getTokenEndpoint(), - e.getMessage(), - e); - throw new DatabricksException( - String.format( - "Failed to exchange ID token for access token at %s: %s", - endpoints.getTokenEndpoint(), e.getMessage()), - e); - } - OAuthResponse response; try { - response = OBJECT_MAPPER.readValue(rawResponse.getBody(), OAuthResponse.class); - } catch (IOException e) { + response = + TokenEndpointClient.requestToken(this.httpClient, endpoints.getTokenEndpoint(), params); + } catch (DatabricksException e) { LOG.error( - "Failed to parse OAuth response from token endpoint {}: {}", + "OAuth token exchange failed for client ID '{}' at {}: {}", + this.clientId, endpoints.getTokenEndpoint(), e.getMessage(), e); - throw new DatabricksException( - String.format( - "Failed to parse OAuth response from token endpoint %s: %s", - endpoints.getTokenEndpoint(), e.getMessage())); + throw e; } - if (response.getErrorCode() != null) { - LOG.error( - "Token exchange failed with error: {} - {}", - response.getErrorCode(), - response.getErrorSummary()); - throw new IllegalArgumentException( - String.format( - "Token exchange failed with error: %s - %s", - response.getErrorCode(), response.getErrorSummary())); - } LocalDateTime expiry = 
LocalDateTime.now().plusSeconds(response.getExpiresIn()); return new Token( response.getAccessToken(), response.getTokenType(), response.getRefreshToken(), expiry); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ErrorTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ErrorTokenSource.java new file mode 100644 index 000000000..0add3d9c6 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ErrorTokenSource.java @@ -0,0 +1,33 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.DatabricksException; +import java.util.Objects; + +/** + * A TokenSource implementation that always throws an error when attempting to get a token. This is + * used when the header factory is not an OAuthHeaderFactory. + */ +public class ErrorTokenSource implements TokenSource { + private final String errorMessage; + + /** + * Constructs a new ErrorTokenSource with the specified error message. + * + * @param errorMessage The error message that will be thrown when attempting to get a token + * @throws NullPointerException if errorMessage is null + */ + public ErrorTokenSource(String errorMessage) { + this.errorMessage = Objects.requireNonNull(errorMessage, "errorMessage cannot be null"); + } + + /** + * Always throws a DatabricksException with the configured error message. + * + * @return never returns normally, always throws an exception + * @throws DatabricksException with the configured error message + */ + @Override + public Token getToken() { + throw new DatabricksException(errorMessage); + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java index b8aa4c66f..7bae60022 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/ExternalBrowserCredentialsProvider.java @@ -3,7 +3,6 @@ import com.databricks.sdk.core.CredentialsProvider; import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; -import com.databricks.sdk.core.HeaderFactory; import java.io.IOException; import java.nio.file.Path; import java.util.Objects; @@ -44,7 +43,7 @@ public String authType() { } @Override - public HeaderFactory configure(DatabricksConfig config) { + public OAuthHeaderFactory configure(DatabricksConfig config) { if (config.getHost() == null || !Objects.equals(config.getAuthType(), "external-browser")) { return null; } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthHeaderFactory.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthHeaderFactory.java new file mode 100644 index 000000000..614614c55 --- /dev/null +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthHeaderFactory.java @@ -0,0 +1,59 @@ +package com.databricks.sdk.core.oauth; + +import com.databricks.sdk.core.HeaderFactory; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Supplier; + +/** + * Factory interface for creating OAuth authentication headers. This interface combines the + * functionality of {@link HeaderFactory} and {@link TokenSource}. + */ +public interface OAuthHeaderFactory extends HeaderFactory, TokenSource { + /** + * Creates an OAuthHeaderFactory from separate token and header suppliers. 
This allows for custom + * header generation beyond just the Authorization header. + * + * @param tokenSupplier A supplier that provides OAuth tokens + * @param headerSupplier A supplier that provides a map of header name-value pairs + * @return A new OAuthHeaderFactory instance that uses the provided suppliers + */ + static OAuthHeaderFactory fromSuppliers( + Supplier tokenSupplier, Supplier> headerSupplier) { + return new OAuthHeaderFactory() { + @Override + public Map headers() { + return headerSupplier.get(); + } + + @Override + public Token getToken() { + return tokenSupplier.get(); + } + }; + } + + /** + * Creates an OAuthHeaderFactory from a TokenSource. This is a convenience method for the common + * case where headers are derived from the token. + * + * @param tokenSource The source of OAuth tokens + * @return A new OAuthHeaderFactory instance that uses the provided token source + */ + static OAuthHeaderFactory fromTokenSource(TokenSource tokenSource) { + return new OAuthHeaderFactory() { + @Override + public Token getToken() { + return tokenSource.getToken(); + } + + @Override + public Map headers() { + Token token = tokenSource.getToken(); + Map headers = new HashMap<>(); + headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); + return headers; + } + }; + } +} diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthM2MServicePrincipalCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthM2MServicePrincipalCredentialsProvider.java index 9b389cb34..058fc268c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthM2MServicePrincipalCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/OAuthM2MServicePrincipalCredentialsProvider.java @@ -1,18 +1,14 @@ package com.databricks.sdk.core.oauth; import com.databricks.sdk.core.*; -import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; import java.util.Collections; -import java.util.HashMap; -import java.util.Map; /** * Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request, if * /oidc/.well-known/oauth-authorization-server is available on the given host. 
*/ public class OAuthM2MServicePrincipalCredentialsProvider implements CredentialsProvider { - private final ObjectMapper mapper = new ObjectMapper(); @Override public String authType() { @@ -20,7 +16,7 @@ public String authType() { } @Override - public HeaderFactory configure(DatabricksConfig config) { + public OAuthHeaderFactory configure(DatabricksConfig config) { if (config.getClientId() == null || config.getClientSecret() == null || config.getHost() == null) { @@ -41,12 +37,7 @@ public HeaderFactory configure(DatabricksConfig config) { .withAuthParameterPosition(AuthParameterPosition.HEADER) .build(); - return () -> { - Token token = tokenSource.getToken(); - Map headers = new HashMap<>(); - headers.put("Authorization", token.getTokenType() + " " + token.getAccessToken()); - return headers; - }; + return OAuthHeaderFactory.fromTokenSource(tokenSource); } catch (IOException e) { // TODO: Log exception throw new DatabricksException("Unable to fetch OIDC endpoint: " + e.getMessage(), e); diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java index 9114b6d6c..4d2d512e3 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/SessionCredentials.java @@ -3,12 +3,10 @@ import com.databricks.sdk.core.CredentialsProvider; import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; -import com.databricks.sdk.core.HeaderFactory; import com.databricks.sdk.core.http.HttpClient; import java.io.Serializable; import java.util.HashMap; import java.util.Map; -import org.apache.http.HttpHeaders; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -30,13 +28,8 @@ public String authType() { } @Override - public HeaderFactory configure(DatabricksConfig config) { - return () -> { - Map headers = new HashMap<>(); - headers.put( - HttpHeaders.AUTHORIZATION, getToken().getTokenType() + " " + getToken().getAccessToken()); - return headers; - }; + public OAuthHeaderFactory configure(DatabricksConfig config) { + return OAuthHeaderFactory.fromTokenSource(this); } static class Builder { diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java index 5b098d076..9a341b901 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/oauth/TokenSourceCredentialsProvider.java @@ -2,9 +2,6 @@ import com.databricks.sdk.core.CredentialsProvider; import com.databricks.sdk.core.DatabricksConfig; -import com.databricks.sdk.core.HeaderFactory; -import java.util.HashMap; -import java.util.Map; /** * A credentials provider that uses a TokenSource to obtain and manage authentication tokens. This @@ -39,18 +36,12 @@ public TokenSourceCredentialsProvider(TokenSource tokenSource, String authType) * acquisition fails. 
*/ @Override - public HeaderFactory configure(DatabricksConfig config) { + public OAuthHeaderFactory configure(DatabricksConfig config) { try { // Validate that we can get a token before returning a HeaderFactory tokenSource.getToken().getAccessToken(); - return () -> { - Map headers = new HashMap<>(); - // Some TokenSource implementations cache tokens internally, so an additional getToken() - // call is not costly - headers.put("Authorization", "Bearer " + tokenSource.getToken().getAccessToken()); - return headers; - }; + return OAuthHeaderFactory.fromTokenSource(tokenSource); } catch (Exception e) { return null; } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/service/serving/ServingEndpointsDataPlaneAPI.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/service/serving/ServingEndpointsDataPlaneAPI.java index 05aef2cb6..3afdc690c 100755 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/service/serving/ServingEndpointsDataPlaneAPI.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/service/serving/ServingEndpointsDataPlaneAPI.java @@ -2,6 +2,7 @@ package com.databricks.sdk.service.serving; import com.databricks.sdk.core.ApiClient; +import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.support.Generated; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -17,8 +18,9 @@ public class ServingEndpointsDataPlaneAPI { private final ServingEndpointsDataPlaneService impl; /** Regular-use constructor */ - public ServingEndpointsDataPlaneAPI(ApiClient apiClient) { - impl = new ServingEndpointsDataPlaneImpl(apiClient); + public ServingEndpointsDataPlaneAPI( + ApiClient apiClient, DatabricksConfig config, ServingEndpointsAPI servingEndpointsAPI) { + impl = new ServingEndpointsDataPlaneImpl(apiClient, config, servingEndpointsAPI); } /** Constructor for mocks */ diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/service/serving/ServingEndpointsDataPlaneImpl.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/service/serving/ServingEndpointsDataPlaneImpl.java index 2dabe61d2..46e61fdc8 100755 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/service/serving/ServingEndpointsDataPlaneImpl.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/service/serving/ServingEndpointsDataPlaneImpl.java @@ -2,29 +2,65 @@ package com.databricks.sdk.service.serving; import com.databricks.sdk.core.ApiClient; +import com.databricks.sdk.core.DatabricksConfig; import com.databricks.sdk.core.DatabricksException; import com.databricks.sdk.core.http.Request; +import com.databricks.sdk.core.http.RequestOptions; +import com.databricks.sdk.core.oauth.DataPlaneTokenSource; +import com.databricks.sdk.core.oauth.Token; import com.databricks.sdk.support.Generated; import java.io.IOException; +import java.util.concurrent.ConcurrentHashMap; /** Package-local implementation of ServingEndpointsDataPlane */ @Generated class ServingEndpointsDataPlaneImpl implements ServingEndpointsDataPlaneService { private final ApiClient apiClient; + private final ServingEndpointsAPI servingEndpointsAPI; + private final DataPlaneTokenSource dataPlaneTokenSource; + private final ConcurrentHashMap infos; - public ServingEndpointsDataPlaneImpl(ApiClient apiClient) { + public ServingEndpointsDataPlaneImpl( + ApiClient apiClient, DatabricksConfig config, ServingEndpointsAPI servingEndpointsAPI) { this.apiClient = apiClient; + this.servingEndpointsAPI = servingEndpointsAPI; + this.dataPlaneTokenSource = + new DataPlaneTokenSource( + 
apiClient.getHttpClient(), config.getTokenSource(), config.getHost()); + this.infos = new ConcurrentHashMap<>(); + } + + private DataPlaneInfo dataPlaneInfoQuery(QueryEndpointInput request) { + String key = + String.format( + "Query/%s", String.join("/", new String[] {String.valueOf(request.getName())})); + + return infos.computeIfAbsent( + key, + k -> { + ServingEndpointDetailed response = + servingEndpointsAPI.get(new GetServingEndpointRequest().setName(request.getName())); + return response.getDataPlaneInfo().getQueryInfo(); + }); } @Override public QueryEndpointResponse query(QueryEndpointInput request) { - String path = String.format("/serving-endpoints/%s/invocations", request.getName()); + DataPlaneInfo dataPlaneInfo = dataPlaneInfoQuery(request); + String path = dataPlaneInfo.getEndpointUrl(); + Token token = dataPlaneTokenSource.getToken(path, dataPlaneInfo.getAuthorizationDetails()); + try { Request req = new Request("POST", path, apiClient.serialize(request)); ApiClient.setQuery(req, request); req.withHeader("Accept", "application/json"); req.withHeader("Content-Type", "application/json"); - return apiClient.execute(req, QueryEndpointResponse.class); + RequestOptions options = + new RequestOptions() + .withAuthorization(token.getTokenType() + " " + token.getAccessToken()) + .withUrl(path); + + return apiClient.execute(req, QueryEndpointResponse.class, options); } catch (IOException e) { throw new DatabricksException("IO error: " + e.getMessage(), e); } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java index e552a1427..38b6fcd9c 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksConfigTest.java @@ -1,11 +1,18 @@ package com.databricks.sdk.core; import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; import com.databricks.sdk.core.commons.CommonsHttpClient; +import com.databricks.sdk.core.http.HttpClient; +import com.databricks.sdk.core.oauth.ErrorTokenSource; +import com.databricks.sdk.core.oauth.OAuthHeaderFactory; import com.databricks.sdk.core.oauth.OpenIDConnectEndpoints; +import com.databricks.sdk.core.oauth.Token; +import com.databricks.sdk.core.oauth.TokenSource; import com.databricks.sdk.core.utils.Environment; import java.io.IOException; +import java.time.LocalDateTime; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -195,4 +202,53 @@ public void testClone() { assert newWorkspaceConfig.getClientId().equals("my-client-id"); assert newWorkspaceConfig.getClientSecret().equals("my-client-secret"); } + + @Test + public void testGetTokenSourceWithNonOAuth() { + HttpClient httpClient = mock(HttpClient.class); + HeaderFactory mockHeaderFactory = mock(HeaderFactory.class); + CredentialsProvider mockProvider = mock(CredentialsProvider.class); + when(mockProvider.authType()).thenReturn("test"); + when(mockProvider.configure(any())).thenReturn(mockHeaderFactory); + + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://test.databricks.com") + .setHttpClient(httpClient) + .setCredentialsProvider(mockProvider); + + // This will set the headerFactory internally + config.authenticate(); + + TokenSource tokenSource = config.getTokenSource(); + assertTrue(tokenSource instanceof ErrorTokenSource); + DatabricksException exception = + 
assertThrows(DatabricksException.class, () -> tokenSource.getToken()); + assertEquals("OAuth Token not supported for current auth type test", exception.getMessage()); + } + + @Test + public void testGetTokenSourceWithOAuth() { + HttpClient httpClient = mock(HttpClient.class); + TokenSource mockTokenSource = mock(TokenSource.class); + when(mockTokenSource.getToken()) + .thenReturn(new Token("test-token", "Bearer", LocalDateTime.now().plusHours(1))); + OAuthHeaderFactory mockHeaderFactory = OAuthHeaderFactory.fromTokenSource(mockTokenSource); + CredentialsProvider mockProvider = mock(CredentialsProvider.class); + when(mockProvider.authType()).thenReturn("test"); + when(mockProvider.configure(any())).thenReturn(mockHeaderFactory); + + DatabricksConfig config = + new DatabricksConfig() + .setHost("https://test.databricks.com") + .setHttpClient(httpClient) + .setCredentialsProvider(mockProvider); + + // This will set the headerFactory internally + config.authenticate(); + + TokenSource tokenSource = config.getTokenSource(); + assertFalse(tokenSource instanceof ErrorTokenSource); + assertEquals(tokenSource.getToken().getAccessToken(), "test-token"); + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/http/RequestOptionsTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/http/RequestOptionsTest.java new file mode 100644 index 000000000..2408dbc99 --- /dev/null +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/http/RequestOptionsTest.java @@ -0,0 +1,89 @@ +package com.databricks.sdk.core.http; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +public class RequestOptionsTest { + private static final String DEFAULT_METHOD = "GET"; + private static final String DEFAULT_URL = "https://example.com"; + private static final String DEFAULT_AUTH = "Bearer token123"; + private static final String DEFAULT_USER_AGENT = "TestAgent/1.0"; + private static final String NEW_URL = "https://new-url.com/api/v1"; + private static final String NEW_AUTH = "Bearer token456"; + private static final String NEW_USER_AGENT = "NewAgent/1.0"; + + private static Request createDefaultRequest() { + return new Request(DEFAULT_METHOD, DEFAULT_URL) + .withHeader("Authorization", DEFAULT_AUTH) + .withHeader("User-Agent", DEFAULT_USER_AGENT); + } + + private static Stream provideTestCases() { + return Stream.of( + // Default constructor test + Arguments.of( + "Default constructor should not modify request", + new RequestOptions(), + DEFAULT_AUTH, + DEFAULT_URL, + DEFAULT_USER_AGENT), + // Authorization header test + Arguments.of( + "Authorization header should be updated", + new RequestOptions().withAuthorization(NEW_AUTH), + NEW_AUTH, + DEFAULT_URL, + DEFAULT_USER_AGENT), + // URL test + Arguments.of( + "URL should be updated", + new RequestOptions().withUrl(NEW_URL), + DEFAULT_AUTH, + NEW_URL, + DEFAULT_USER_AGENT), + // User-Agent test + Arguments.of( + "User-Agent header should be updated", + new RequestOptions().withUserAgent(NEW_USER_AGENT), + DEFAULT_AUTH, + DEFAULT_URL, + NEW_USER_AGENT), + // Multiple options test + Arguments.of( + "Multiple options should be applied", + new RequestOptions() + .withAuthorization(NEW_AUTH) + .withUrl(NEW_URL) + .withUserAgent(NEW_USER_AGENT), + NEW_AUTH, + NEW_URL, + NEW_USER_AGENT)); + } + + @ParameterizedTest(name = "{0}") + 
@MethodSource("provideTestCases") + public void testRequestOptions( + String testName, + RequestOptions options, + String expectedAuth, + String expectedUrl, + String expectedUserAgent) { + + Request originalRequest = createDefaultRequest(); + Request result = options.applyOptions(originalRequest); + + // Verify method is unchanged + assertEquals(DEFAULT_METHOD, result.getMethod()); + + // Verify URL + assertEquals(expectedUrl, result.getUrl()); + + // Verify headers + assertEquals(expectedAuth, result.getHeaders().get("Authorization")); + assertEquals(expectedUserAgent, result.getHeaders().get("User-Agent")); + } +} diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProviderTest.java index 10b9c0ecc..f67beceb4 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/AzureGithubOidcCredentialsProviderTest.java @@ -25,7 +25,7 @@ public class AzureGithubOidcCredentialsProviderTest { private static final String OAUTH_RESPONSE = new JSONObject() .put("access_token", TOKEN) - .put("token_type", "token-type") + .put("token_type", "Bearer") .put("expires_in", 360) .toString(); diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java index 8d7da8d3a..8217179f2 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/DatabricksOAuthTokenSourceTest.java @@ -15,7 +15,6 @@ import java.util.HashMap; import java.util.Map; import java.util.stream.Stream; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mockito; @@ -35,45 +34,42 @@ class DatabricksOAuthTokenSourceTest { private static final String TEST_AUDIENCE = "test-audience"; private static final String TEST_ACCOUNT_ID = "test-account-id"; - // Error message constants - private static final String ERROR_NULL = "Required parameter '%s' cannot be null"; - private static final String ERROR_EMPTY = "Required parameter '%s' cannot be empty"; - - private IDTokenSource mockIdTokenSource; - - @BeforeEach - void setUp() { - mockIdTokenSource = Mockito.mock(IDTokenSource.class); - IDToken idToken = new IDToken(TEST_ID_TOKEN); - when(mockIdTokenSource.getIDToken(any())).thenReturn(idToken); - } - /** * Test case data for parameterized token source tests. Each case defines a specific OAuth token * exchange scenario. 
*/ private static class TestCase { final String name; // Descriptive name of the test case + final String clientId; // Client ID to use + final String host; // Host to use + final OpenIDConnectEndpoints endpoints; // OIDC endpoints + final IDTokenSource idTokenSource; // ID token source + final HttpClient httpClient; // HTTP client final String audience; // Custom audience value if provided final String accountId; // Account ID if provided final String expectedAudience; // Expected audience used in token exchange - final HttpClient mockHttpClient; // Pre-configured mock HTTP client final Class expectedException; // Expected exception type if any TestCase( String name, + String clientId, + String host, + OpenIDConnectEndpoints endpoints, + IDTokenSource idTokenSource, + HttpClient httpClient, String audience, String accountId, String expectedAudience, - int statusCode, - Object responseBody, - HttpClient mockHttpClient, Class expectedException) { this.name = name; + this.clientId = clientId; + this.host = host; + this.endpoints = endpoints; + this.idTokenSource = idTokenSource; + this.httpClient = httpClient; this.audience = audience; this.accountId = accountId; this.expectedAudience = expectedAudience; - this.mockHttpClient = mockHttpClient; this.expectedException = expectedException; } @@ -87,20 +83,27 @@ public String toString() { * Provides test cases for OAuth token exchange scenarios. Includes success cases with different * audience configurations and various error cases. */ - private static Stream provideTestCases() { - try { - // Success response with valid token data - Map successResponse = new HashMap<>(); - successResponse.put("access_token", TOKEN); - successResponse.put("token_type", TOKEN_TYPE); - successResponse.put("refresh_token", REFRESH_TOKEN); - successResponse.put("expires_in", EXPIRES_IN); + private static Stream provideTestCases() throws MalformedURLException { + // Create valid components for reuse + OpenIDConnectEndpoints testEndpoints = + new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); + IDTokenSource testIdTokenSource = Mockito.mock(IDTokenSource.class); + IDToken idToken = new IDToken(TEST_ID_TOKEN); + when(testIdTokenSource.getIDToken(any())).thenReturn(idToken); + + // Create success response for token exchange tests + Map successResponse = new HashMap<>(); + successResponse.put("access_token", TOKEN); + successResponse.put("token_type", TOKEN_TYPE); + successResponse.put("refresh_token", REFRESH_TOKEN); + successResponse.put("expires_in", EXPIRES_IN); - // Error response for invalid requests - Map errorResponse = new HashMap<>(); - errorResponse.put("error", "invalid_request"); - errorResponse.put("error_description", "Invalid client ID"); + // Create error response for invalid requests + Map errorResponse = new HashMap<>(); + errorResponse.put("error", "invalid_request"); + errorResponse.put("error_description", "Invalid client ID"); + try { ObjectMapper mapper = new ObjectMapper(); final String errorJson = mapper.writeValueAsString(errorResponse); final String successJson = mapper.writeValueAsString(successResponse); @@ -115,71 +118,162 @@ private static Stream provideTestCases() { FormRequest expectedRequest = new FormRequest(TEST_TOKEN_ENDPOINT, formParams); return Stream.of( - // Success cases with different audience configurations + // Token exchange test cases new TestCase( "Default audience from token endpoint", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 
200, successJson), null, null, TEST_TOKEN_ENDPOINT, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Custom audience provided", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), TEST_AUDIENCE, null, TEST_AUDIENCE, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Custom audience takes precedence over account ID", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), TEST_AUDIENCE, TEST_ACCOUNT_ID, TEST_AUDIENCE, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), new TestCase( "Account ID used as audience when no custom audience", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), null, TEST_ACCOUNT_ID, TEST_ACCOUNT_ID, - 200, - successResponse, - createMockHttpClient(expectedRequest, 200, successJson), null), - // Error cases new TestCase( "Invalid request returns 400", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 400, errorJson), null, null, TEST_TOKEN_ENDPOINT, - 400, - errorJson, - createMockHttpClient(expectedRequest, 400, errorJson), - IllegalArgumentException.class), + DatabricksException.class), new TestCase( "Network error during token exchange", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClientWithError(expectedRequest), null, null, TEST_TOKEN_ENDPOINT, - 0, - null, - createMockHttpClientWithError(expectedRequest), DatabricksException.class), new TestCase( "Invalid JSON response from server", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, "invalid json"), null, null, TEST_TOKEN_ENDPOINT, - 200, - "invalid json", - createMockHttpClient(expectedRequest, 200, "invalid json"), - DatabricksException.class)); + DatabricksException.class), + // Parameter validation test cases + new TestCase( + "Null client ID", + null, + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Empty client ID", + "", + TEST_HOST, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + IllegalArgumentException.class), + new TestCase( + "Null host", + TEST_CLIENT_ID, + null, + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Empty host", + TEST_CLIENT_ID, + "", + testEndpoints, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + IllegalArgumentException.class), + new TestCase( + "Null endpoints", + TEST_CLIENT_ID, + TEST_HOST, + null, + testIdTokenSource, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Null IDTokenSource", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + null, + createMockHttpClient(expectedRequest, 200, successJson), + null, + null, + null, + NullPointerException.class), + new TestCase( + "Null HttpClient", + TEST_CLIENT_ID, + TEST_HOST, + testEndpoints, + testIdTokenSource, + null, + null, + null, 
+ null, + NullPointerException.class)); } catch (IOException e) { throw new RuntimeException("Failed to create test cases", e); } @@ -212,179 +306,34 @@ private static HttpClient createMockHttpClientWithError(FormRequest expectedRequ * Tests OAuth token exchange with various configurations and error scenarios. Verifies correct * audience selection, token exchange, and error handling. */ - @ParameterizedTest(name = "testTokenSource: {arguments}") + @ParameterizedTest(name = "{0}") @MethodSource("provideTestCases") void testTokenSource(TestCase testCase) { - try { - // Create token source with test configuration - OpenIDConnectEndpoints endpoints = - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - - DatabricksOAuthTokenSource.Builder builder = - new DatabricksOAuthTokenSource.Builder( - TEST_CLIENT_ID, TEST_HOST, endpoints, mockIdTokenSource, testCase.mockHttpClient); - - builder.audience(testCase.audience).accountId(testCase.accountId); - - DatabricksOAuthTokenSource tokenSource = builder.build(); - - if (testCase.expectedException != null) { - assertThrows(testCase.expectedException, () -> tokenSource.getToken()); - } else { - // Verify successful token exchange - Token token = tokenSource.getToken(); - assertEquals(TOKEN, token.getAccessToken()); - assertEquals(TOKEN_TYPE, token.getTokenType()); - assertEquals(REFRESH_TOKEN, token.getRefreshToken()); - assertFalse(token.isExpired()); + DatabricksOAuthTokenSource.Builder builder = + new DatabricksOAuthTokenSource.Builder( + testCase.clientId, + testCase.host, + testCase.endpoints, + testCase.idTokenSource, + testCase.httpClient); - // Verify correct audience was used - verify(mockIdTokenSource).getIDToken(testCase.expectedAudience); - } - } catch (IOException e) { - throw new RuntimeException("Test failed", e); - } - } + builder.audience(testCase.audience); + builder.accountId(testCase.accountId); - /** - * Test case data for parameter validation tests. Each case defines a specific validation - * scenario. 
- */ - private static class ValidationTestCase { - final String name; - final String clientId; - final String host; - final OpenIDConnectEndpoints endpoints; - final IDTokenSource idTokenSource; - final HttpClient httpClient; - final String expectedFieldName; - final boolean isNullTest; + DatabricksOAuthTokenSource tokenSource = builder.build(); - ValidationTestCase( - String name, - String clientId, - String host, - OpenIDConnectEndpoints endpoints, - IDTokenSource idTokenSource, - HttpClient httpClient, - String expectedFieldName, - boolean isNullTest) { - this.name = name; - this.clientId = clientId; - this.host = host; - this.endpoints = endpoints; - this.idTokenSource = idTokenSource; - this.httpClient = httpClient; - this.expectedFieldName = expectedFieldName; - this.isNullTest = isNullTest; - } + if (testCase.expectedException != null) { + assertThrows(testCase.expectedException, () -> tokenSource.getToken()); + } else { + // Verify successful token exchange + Token token = tokenSource.getToken(); + assertEquals(TOKEN, token.getAccessToken()); + assertEquals(TOKEN_TYPE, token.getTokenType()); + assertEquals(REFRESH_TOKEN, token.getRefreshToken()); + assertFalse(token.isExpired()); - @Override - public String toString() { - return name; + // Verify correct audience was used + verify(testCase.idTokenSource, atLeastOnce()).getIDToken(testCase.expectedAudience); } } - - private static Stream provideValidationTestCases() - throws MalformedURLException { - OpenIDConnectEndpoints validEndpoints = - new OpenIDConnectEndpoints(TEST_TOKEN_ENDPOINT, TEST_AUTHORIZATION_ENDPOINT); - HttpClient validHttpClient = Mockito.mock(HttpClient.class); - IDTokenSource validIdTokenSource = Mockito.mock(IDTokenSource.class); - - return Stream.of( - // Client ID validation - new ValidationTestCase( - "Null client ID", - null, - TEST_HOST, - validEndpoints, - validIdTokenSource, - validHttpClient, - "ClientID", - true), - new ValidationTestCase( - "Empty client ID", - "", - TEST_HOST, - validEndpoints, - validIdTokenSource, - validHttpClient, - "ClientID", - false), - // Host validation - new ValidationTestCase( - "Null host", - TEST_CLIENT_ID, - null, - validEndpoints, - validIdTokenSource, - validHttpClient, - "Host", - true), - new ValidationTestCase( - "Empty host", - TEST_CLIENT_ID, - "", - validEndpoints, - validIdTokenSource, - validHttpClient, - "Host", - false), - // Endpoints validation - new ValidationTestCase( - "Null endpoints", - TEST_CLIENT_ID, - TEST_HOST, - null, - validIdTokenSource, - validHttpClient, - "Endpoints", - true), - // IDTokenSource validation - new ValidationTestCase( - "Null IDTokenSource", - TEST_CLIENT_ID, - TEST_HOST, - validEndpoints, - null, - validHttpClient, - "IDTokenSource", - true), - // HttpClient validation - new ValidationTestCase( - "Null HttpClient", - TEST_CLIENT_ID, - TEST_HOST, - validEndpoints, - validIdTokenSource, - null, - "HttpClient", - true)); - } - - /** - * Tests validation of required fields in the token source using parameterized test cases. - * Verifies that null or empty values for required fields cause getToken() to throw - * IllegalArgumentException with specific error messages. 
-   */
-  @ParameterizedTest(name = "testParameterValidation: {0}")
-  @MethodSource("provideValidationTestCases")
-  void testParameterValidation(ValidationTestCase testCase) {
-    DatabricksOAuthTokenSource tokenSource =
-        new DatabricksOAuthTokenSource.Builder(
-                testCase.clientId,
-                testCase.host,
-                testCase.endpoints,
-                testCase.idTokenSource,
-                testCase.httpClient)
-            .build();
-
-    IllegalArgumentException exception =
-        assertThrows(IllegalArgumentException.class, () -> tokenSource.getToken());
-
-    String expectedMessage =
-        String.format(testCase.isNullTest ? ERROR_NULL : ERROR_EMPTY, testCase.expectedFieldName);
-    assertEquals(expectedMessage, exception.getMessage());
-  }
 }
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ErrorTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ErrorTokenSourceTest.java
new file mode 100644
index 000000000..5bf66f19c
--- /dev/null
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/ErrorTokenSourceTest.java
@@ -0,0 +1,34 @@
+package com.databricks.sdk.core.oauth;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import com.databricks.sdk.core.DatabricksException;
+import org.junit.jupiter.api.Test;
+
+public class ErrorTokenSourceTest {
+
+  @Test
+  public void testGetTokenThrowsException() {
+    String errorMessage = "Test error message";
+    ErrorTokenSource tokenSource = new ErrorTokenSource(errorMessage);
+
+    DatabricksException exception =
+        assertThrows(
+            DatabricksException.class,
+            () -> tokenSource.getToken(),
+            "Expected getToken() to throw DatabricksException");
+
+    assertEquals(
+        errorMessage,
+        exception.getMessage(),
+        "Exception message should match the one provided in constructor");
+  }
+
+  @Test
+  public void testConstructorWithNullErrorMessage() {
+    assertThrows(
+        NullPointerException.class,
+        () -> new ErrorTokenSource(null),
+        "Expected constructor to throw NullPointerException when error message is null");
+  }
+}
diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthHeaderFactoryTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthHeaderFactoryTest.java
new file mode 100644
index 000000000..d0530b2c1
--- /dev/null
+++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/oauth/OAuthHeaderFactoryTest.java
@@ -0,0 +1,96 @@
+package com.databricks.sdk.core.oauth;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+import java.time.LocalDateTime;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Supplier;
+import java.util.stream.Stream;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+@ExtendWith(MockitoExtension.class)
+public class OAuthHeaderFactoryTest {
+
+  private static final String TOKEN_TYPE = "Bearer";
+  private static final String TOKEN_VALUE = "test-token";
+
+  @Mock private TokenSource tokenSource;
+
+  private static Stream<Arguments> provideTokenSourceTestCases() {
+    LocalDateTime expiry = LocalDateTime.now().plusHours(1);
+    Token token = new Token(TOKEN_VALUE, TOKEN_TYPE, expiry);
+
+    return Stream.of(
+        Arguments.of(
+            "Standard token source",
+            token,
+            Collections.singletonMap("Authorization", TOKEN_TYPE + " " + TOKEN_VALUE)),
+        Arguments.of(
"Token with custom type", + new Token(TOKEN_VALUE, "Custom", expiry), + Collections.singletonMap("Authorization", "Custom " + TOKEN_VALUE))); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideTokenSourceTestCases") + public void testFromTokenSourceFactoryMethod( + String testName, Token token, Map expectedHeaders) { + when(tokenSource.getToken()).thenReturn(token); + + OAuthHeaderFactory factory = OAuthHeaderFactory.fromTokenSource(tokenSource); + + assertNotNull(factory, "Factory should not be null"); + + Token actualToken = factory.getToken(); + assertEquals(token, actualToken, "Factory should return the same token as the source"); + + Map headers = factory.headers(); + assertEquals(expectedHeaders, headers, "Factory should generate correct headers"); + } + + private static Stream provideSuppliersTestCases() { + LocalDateTime expiry = LocalDateTime.now().plusHours(1); + Token token = new Token(TOKEN_VALUE, TOKEN_TYPE, expiry); + + Map standardHeaders = new HashMap<>(); + standardHeaders.put("Authorization", TOKEN_TYPE + " " + TOKEN_VALUE); + standardHeaders.put("Content-Type", "application/json"); + + Map multipleHeaders = new HashMap<>(); + multipleHeaders.put("Authorization", TOKEN_TYPE + " " + TOKEN_VALUE); + multipleHeaders.put("X-Custom-Header", "custom-value"); + multipleHeaders.put("Accept", "application/json"); + + return Stream.of( + Arguments.of("Standard suppliers", token, standardHeaders), + Arguments.of("Empty headers", token, new HashMap<>()), + Arguments.of("Multiple custom headers", token, multipleHeaders)); + } + + @ParameterizedTest(name = "{0}") + @MethodSource("provideSuppliersTestCases") + public void testFromSuppliersFactoryMethod( + String testName, Token token, Map expectedHeaders) { + Supplier tokenSupplier = () -> token; + Supplier> headerSupplier = () -> new HashMap<>(expectedHeaders); + + OAuthHeaderFactory factory = OAuthHeaderFactory.fromSuppliers(tokenSupplier, headerSupplier); + + assertNotNull(factory, "Factory should not be null"); + + Token actualToken = factory.getToken(); + assertEquals(token, actualToken, "Factory should return the same token as the supplier"); + + Map actualHeaders = factory.headers(); + assertEquals( + expectedHeaders, actualHeaders, "Factory should return the same headers as the supplier"); + } +}