Skip to content

Commit 5203e57

Browse files
committed
Add GithubIDTokenSource and TokenSourceCredentialsProvider
1 parent efe5aa4 commit 5203e57

6 files changed

Lines changed: 410 additions & 72 deletions

File tree

databricks-sdk-java/src/main/java/com/databricks/sdk/core/DefaultCredentialsProvider.java

Lines changed: 75 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,56 +2,44 @@
22

33
import com.databricks.sdk.core.oauth.*;
44
import java.util.ArrayList;
5-
import java.util.Arrays;
65
import java.util.List;
76
import org.slf4j.Logger;
87
import org.slf4j.LoggerFactory;
98

109
public class DefaultCredentialsProvider implements CredentialsProvider {
1110
private static final Logger LOG = LoggerFactory.getLogger(DefaultCredentialsProvider.class);
1211

13-
private static final List<Class<?>> providerClasses =
14-
Arrays.asList(
15-
PatCredentialsProvider.class,
16-
BasicCredentialsProvider.class,
17-
OAuthM2MServicePrincipalCredentialsProvider.class,
18-
GithubOidcCredentialsProvider.class,
19-
AzureGithubOidcCredentialsProvider.class,
20-
AzureServicePrincipalCredentialsProvider.class,
21-
AzureCliCredentialsProvider.class,
22-
ExternalBrowserCredentialsProvider.class,
23-
DatabricksCliCredentialsProvider.class,
24-
NotebookNativeCredentialsProvider.class,
25-
GoogleCredentialsCredentialsProvider.class,
26-
GoogleIdCredentialsProvider.class);
27-
28-
private final List<CredentialsProvider> providers;
12+
private List<CredentialsProvider> providers = new ArrayList<>();
2913

3014
private String authType = "default";
3115

32-
public String authType() {
33-
return authType;
34-
}
16+
private static class NamedIDTokenSource {
17+
private final String name;
18+
private final IDTokenSource idTokenSource;
3519

36-
public DefaultCredentialsProvider() {
37-
providers = new ArrayList<>();
38-
for (Class<?> clazz : providerClasses) {
39-
try {
40-
providers.add((CredentialsProvider) clazz.newInstance());
41-
} catch (NoClassDefFoundError | InstantiationException | IllegalAccessException e) {
42-
LOG.warn(
43-
"Failed to instantiate credentials provider: "
44-
+ clazz.getName()
45-
+ ", skipping. Cause: "
46-
+ e.getClass().getCanonicalName()
47-
+ ": "
48-
+ e.getMessage());
49-
}
20+
public NamedIDTokenSource(String name, IDTokenSource idTokenSource) {
21+
this.name = name;
22+
this.idTokenSource = idTokenSource;
23+
}
24+
25+
public String getName() {
26+
return name;
27+
}
28+
29+
public IDTokenSource getIdTokenSource() {
30+
return idTokenSource;
5031
}
5132
}
5233

34+
public DefaultCredentialsProvider() {}
35+
36+
public String authType() {
37+
return authType;
38+
}
39+
5340
@Override
5441
public synchronized HeaderFactory configure(DatabricksConfig config) {
42+
addDefaultCredentialsProviders(config);
5543
for (CredentialsProvider provider : providers) {
5644
if (config.getAuthType() != null
5745
&& !config.getAuthType().isEmpty()
@@ -80,4 +68,57 @@ public synchronized HeaderFactory configure(DatabricksConfig config) {
8068
+ authFlowUrl
8169
+ " to configure credentials for your preferred authentication method");
8270
}
71+
72+
private void addOIDCCredentialsProviders(DatabricksConfig config) {
73+
OpenIDConnectEndpoints endpoints = null;
74+
try {
75+
endpoints = config.getOidcEndpoints();
76+
} catch (Exception e) {
77+
LOG.warn("Failed to get OpenID Connect endpoints", e);
78+
}
79+
80+
List<NamedIDTokenSource> namedIdTokenSources = new ArrayList<>();
81+
namedIdTokenSources.add(
82+
new NamedIDTokenSource(
83+
"github-oidc",
84+
new GithubIDTokenSource(
85+
config.getActionsIdTokenRequestUrl(),
86+
config.getActionsIdTokenRequestToken(),
87+
config.getHttpClient())));
88+
// Add new IDTokenSources and ID providers here. Example:
89+
// namedIdTokenSources.add(new NamedIDTokenSource("custom-oidc", new CustomIDTokenSource(...)));
90+
91+
for (NamedIDTokenSource namedIdTokenSource : namedIdTokenSources) {
92+
DatabricksOAuthTokenSource oauthTokenSource =
93+
new DatabricksOAuthTokenSource.Builder(
94+
config.getClientId(),
95+
config.getHost(),
96+
endpoints,
97+
namedIdTokenSource.getIdTokenSource(),
98+
config.getHttpClient())
99+
.audience(config.getTokenAudience())
100+
.accountId(config.isAccountClient() ? config.getAccountId() : null)
101+
.build();
102+
103+
providers.add(
104+
new TokenSourceCredentialsProvider(oauthTokenSource, namedIdTokenSource.getName()));
105+
}
106+
}
107+
108+
private void addDefaultCredentialsProviders(DatabricksConfig config) {
109+
providers.add(new PatCredentialsProvider());
110+
providers.add(new BasicCredentialsProvider());
111+
providers.add(new OAuthM2MServicePrincipalCredentialsProvider());
112+
113+
addOIDCCredentialsProviders(config);
114+
115+
providers.add(new AzureGithubOidcCredentialsProvider());
116+
providers.add(new AzureServicePrincipalCredentialsProvider());
117+
providers.add(new AzureCliCredentialsProvider());
118+
providers.add(new ExternalBrowserCredentialsProvider());
119+
providers.add(new DatabricksCliCredentialsProvider());
120+
providers.add(new NotebookNativeCredentialsProvider());
121+
providers.add(new GoogleCredentialsCredentialsProvider());
122+
providers.add(new GoogleIdCredentialsProvider());
123+
}
83124
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.DatabricksException;
4+
import com.databricks.sdk.core.http.HttpClient;
5+
import com.databricks.sdk.core.http.Request;
6+
import com.databricks.sdk.core.http.Response;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
8+
import com.fasterxml.jackson.databind.node.ObjectNode;
9+
import com.google.common.base.Strings;
10+
import java.io.IOException;
11+
12+
/** GithubIDTokenSource retrieves JWT Tokens from GitHub Actions. */
13+
public class GithubIDTokenSource implements IDTokenSource {
14+
private final String actionsIDTokenRequestURL;
15+
private final String actionsIDTokenRequestToken;
16+
private final HttpClient httpClient;
17+
private final ObjectMapper mapper = new ObjectMapper();
18+
19+
/**
20+
* Constructs a new GithubIDTokenSource.
21+
*
22+
* @param actionsIDTokenRequestURL The URL to request the ID token from GitHub Actions.
23+
* @param actionsIDTokenRequestToken The token used to authenticate the request.
24+
* @param httpClient The HTTP client to use for making requests.
25+
*/
26+
public GithubIDTokenSource(
27+
String actionsIDTokenRequestURL, String actionsIDTokenRequestToken, HttpClient httpClient) {
28+
this.actionsIDTokenRequestURL = actionsIDTokenRequestURL;
29+
this.actionsIDTokenRequestToken = actionsIDTokenRequestToken;
30+
this.httpClient = httpClient;
31+
}
32+
33+
@Override
34+
public IDToken getIDToken(String audience) {
35+
if (Strings.isNullOrEmpty(actionsIDTokenRequestURL)) {
36+
throw new DatabricksException("Missing ActionsIDTokenRequestURL");
37+
}
38+
if (Strings.isNullOrEmpty(actionsIDTokenRequestToken)) {
39+
throw new DatabricksException("Missing ActionsIDTokenRequestToken");
40+
}
41+
if (httpClient == null) {
42+
throw new DatabricksException("HttpClient cannot be null");
43+
}
44+
45+
String requestUrl = actionsIDTokenRequestURL;
46+
if (!Strings.isNullOrEmpty(audience)) {
47+
requestUrl = String.format("%s&audience=%s", requestUrl, audience);
48+
}
49+
50+
Request req =
51+
new Request("GET", requestUrl)
52+
.withHeader("Authorization", "Bearer " + actionsIDTokenRequestToken);
53+
54+
Response resp;
55+
try {
56+
resp = httpClient.execute(req);
57+
} catch (IOException e) {
58+
throw new DatabricksException(
59+
"Failed to request ID token from " + requestUrl + ": " + e.getMessage(), e);
60+
}
61+
62+
if (resp.getStatusCode() != 200) {
63+
throw new DatabricksException(
64+
"Failed to request ID token: status code "
65+
+ resp.getStatusCode()
66+
+ ", response body: "
67+
+ resp.getBody().toString());
68+
}
69+
70+
ObjectNode jsonResp;
71+
try {
72+
jsonResp = mapper.readValue(resp.getBody(), ObjectNode.class);
73+
} catch (IOException e) {
74+
throw new DatabricksException(
75+
"Failed to request ID token: corrupted token: " + e.getMessage());
76+
}
77+
78+
if (!jsonResp.has("value")) {
79+
throw new DatabricksException("ID token response missing 'value' field");
80+
}
81+
82+
String tokenValue = jsonResp.get("value").textValue();
83+
if (Strings.isNullOrEmpty(tokenValue)) {
84+
throw new DatabricksException("Received empty ID token from GitHub Actions");
85+
}
86+
87+
return new IDToken(tokenValue);
88+
}
89+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package com.databricks.sdk.core.oauth;
2+
3+
import com.databricks.sdk.core.CredentialsProvider;
4+
import com.databricks.sdk.core.DatabricksConfig;
5+
import com.databricks.sdk.core.HeaderFactory;
6+
import java.util.HashMap;
7+
import java.util.Map;
8+
9+
/** Base class for token-based credentials providers. */
10+
public class TokenSourceCredentialsProvider implements CredentialsProvider {
11+
private final TokenSource tokenSource;
12+
private final String authType;
13+
14+
/**
15+
* Creates a new TokenSourceCredentialsProvider with the specified token source and auth type.
16+
*
17+
* @param tokenSource The token source to use for token exchange
18+
* @param authType The authentication type string
19+
*/
20+
public TokenSourceCredentialsProvider(TokenSource tokenSource, String authType) {
21+
this.tokenSource = tokenSource;
22+
this.authType = authType;
23+
}
24+
25+
@Override
26+
public HeaderFactory configure(DatabricksConfig config) {
27+
try {
28+
// Validate that we can get a token before returning the HeaderFactory
29+
String accessToken = tokenSource.getToken().getAccessToken();
30+
31+
return () -> {
32+
Map<String, String> headers = new HashMap<>();
33+
headers.put("Authorization", "Bearer " + accessToken);
34+
return headers;
35+
};
36+
} catch (Exception e) {
37+
return null;
38+
}
39+
}
40+
41+
@Override
42+
public String authType() {
43+
return authType;
44+
}
45+
}

databricks-sdk-java/src/test/java/com/databricks/sdk/DatabricksAuthManualTest.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@
22

33
import com.databricks.sdk.core.ConfigResolving;
44
import com.databricks.sdk.core.DatabricksConfig;
5+
import com.databricks.sdk.core.DummyHttpClient;
56
import com.databricks.sdk.core.utils.TestOSUtils;
67
import java.util.Map;
78
import org.junit.jupiter.api.Assertions;
89
import org.junit.jupiter.api.Test;
910

1011
public class DatabricksAuthManualTest implements ConfigResolving {
12+
private DatabricksConfig createConfigWithMockClient() {
13+
return new DatabricksConfig().setHttpClient(new DummyHttpClient());
14+
}
15+
1116
@Test
1217
void azureCliWorkspaceHeaderPresent() {
1318
StaticEnv env =
@@ -18,7 +23,7 @@ void azureCliWorkspaceHeaderPresent() {
1823
String azureWorkspaceResourceId =
1924
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
2025
DatabricksConfig config =
21-
new DatabricksConfig()
26+
createConfigWithMockClient()
2227
.setAuthType("azure-cli")
2328
.setHost("https://x")
2429
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);
@@ -38,7 +43,7 @@ void azureCliUserWithManagementAccess() {
3843
String azureWorkspaceResourceId =
3944
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
4045
DatabricksConfig config =
41-
new DatabricksConfig()
46+
createConfigWithMockClient()
4247
.setAuthType("azure-cli")
4348
.setHost("https://x")
4449
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);
@@ -58,7 +63,7 @@ void azureCliUserNoManagementAccess() {
5863
String azureWorkspaceResourceId =
5964
"/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123";
6065
DatabricksConfig config =
61-
new DatabricksConfig()
66+
createConfigWithMockClient()
6267
.setAuthType("azure-cli")
6368
.setHost("https://x")
6469
.setAzureWorkspaceResourceId(azureWorkspaceResourceId);

0 commit comments

Comments
 (0)