|
2 | 2 |
|
3 | 3 | import com.databricks.sdk.core.oauth.*; |
4 | 4 | import java.util.ArrayList; |
5 | | -import java.util.Arrays; |
6 | 5 | import java.util.List; |
7 | 6 | import org.slf4j.Logger; |
8 | 7 | import org.slf4j.LoggerFactory; |
9 | 8 |
|
10 | 9 | public class DefaultCredentialsProvider implements CredentialsProvider { |
11 | 10 | private static final Logger LOG = LoggerFactory.getLogger(DefaultCredentialsProvider.class); |
12 | 11 |
|
13 | | - private static final List<Class<?>> providerClasses = |
14 | | - Arrays.asList( |
15 | | - PatCredentialsProvider.class, |
16 | | - BasicCredentialsProvider.class, |
17 | | - OAuthM2MServicePrincipalCredentialsProvider.class, |
18 | | - GithubOidcCredentialsProvider.class, |
19 | | - AzureGithubOidcCredentialsProvider.class, |
20 | | - AzureServicePrincipalCredentialsProvider.class, |
21 | | - AzureCliCredentialsProvider.class, |
22 | | - ExternalBrowserCredentialsProvider.class, |
23 | | - DatabricksCliCredentialsProvider.class, |
24 | | - NotebookNativeCredentialsProvider.class, |
25 | | - GoogleCredentialsCredentialsProvider.class, |
26 | | - GoogleIdCredentialsProvider.class); |
27 | | - |
28 | | - private final List<CredentialsProvider> providers; |
| 12 | + private List<CredentialsProvider> providers = new ArrayList<>(); |
29 | 13 |
|
30 | 14 | private String authType = "default"; |
31 | 15 |
|
32 | | - public String authType() { |
33 | | - return authType; |
34 | | - } |
| 16 | + private static class NamedIDTokenSource { |
| 17 | + private final String name; |
| 18 | + private final IDTokenSource idTokenSource; |
35 | 19 |
|
36 | | - public DefaultCredentialsProvider() { |
37 | | - providers = new ArrayList<>(); |
38 | | - for (Class<?> clazz : providerClasses) { |
39 | | - try { |
40 | | - providers.add((CredentialsProvider) clazz.newInstance()); |
41 | | - } catch (NoClassDefFoundError | InstantiationException | IllegalAccessException e) { |
42 | | - LOG.warn( |
43 | | - "Failed to instantiate credentials provider: " |
44 | | - + clazz.getName() |
45 | | - + ", skipping. Cause: " |
46 | | - + e.getClass().getCanonicalName() |
47 | | - + ": " |
48 | | - + e.getMessage()); |
49 | | - } |
| 20 | + public NamedIDTokenSource(String name, IDTokenSource idTokenSource) { |
| 21 | + this.name = name; |
| 22 | + this.idTokenSource = idTokenSource; |
| 23 | + } |
| 24 | + |
| 25 | + public String getName() { |
| 26 | + return name; |
| 27 | + } |
| 28 | + |
| 29 | + public IDTokenSource getIdTokenSource() { |
| 30 | + return idTokenSource; |
50 | 31 | } |
51 | 32 | } |
52 | 33 |
|
| 34 | + public DefaultCredentialsProvider() {} |
| 35 | + |
| 36 | + public String authType() { |
| 37 | + return authType; |
| 38 | + } |
| 39 | + |
53 | 40 | @Override |
54 | 41 | public synchronized HeaderFactory configure(DatabricksConfig config) { |
| 42 | + addDefaultCredentialsProviders(config); |
55 | 43 | for (CredentialsProvider provider : providers) { |
56 | 44 | if (config.getAuthType() != null |
57 | 45 | && !config.getAuthType().isEmpty() |
@@ -80,4 +68,57 @@ public synchronized HeaderFactory configure(DatabricksConfig config) { |
80 | 68 | + authFlowUrl |
81 | 69 | + " to configure credentials for your preferred authentication method"); |
82 | 70 | } |
| 71 | + |
| 72 | + private void addOIDCCredentialsProviders(DatabricksConfig config) { |
| 73 | + OpenIDConnectEndpoints endpoints = null; |
| 74 | + try { |
| 75 | + endpoints = config.getOidcEndpoints(); |
| 76 | + } catch (Exception e) { |
| 77 | + LOG.warn("Failed to get OpenID Connect endpoints", e); |
| 78 | + } |
| 79 | + |
| 80 | + List<NamedIDTokenSource> namedIdTokenSources = new ArrayList<>(); |
| 81 | + namedIdTokenSources.add( |
| 82 | + new NamedIDTokenSource( |
| 83 | + "github-oidc", |
| 84 | + new GithubIDTokenSource( |
| 85 | + config.getActionsIdTokenRequestUrl(), |
| 86 | + config.getActionsIdTokenRequestToken(), |
| 87 | + config.getHttpClient()))); |
| 88 | + // Add new IDTokenSources and ID providers here. Example: |
| 89 | + // namedIdTokenSources.add(new NamedIDTokenSource("custom-oidc", new CustomIDTokenSource(...))); |
| 90 | + |
| 91 | + for (NamedIDTokenSource namedIdTokenSource : namedIdTokenSources) { |
| 92 | + DatabricksOAuthTokenSource oauthTokenSource = |
| 93 | + new DatabricksOAuthTokenSource.Builder( |
| 94 | + config.getClientId(), |
| 95 | + config.getHost(), |
| 96 | + endpoints, |
| 97 | + namedIdTokenSource.getIdTokenSource(), |
| 98 | + config.getHttpClient()) |
| 99 | + .audience(config.getTokenAudience()) |
| 100 | + .accountId(config.isAccountClient() ? config.getAccountId() : null) |
| 101 | + .build(); |
| 102 | + |
| 103 | + providers.add( |
| 104 | + new TokenSourceCredentialsProvider(oauthTokenSource, namedIdTokenSource.getName())); |
| 105 | + } |
| 106 | + } |
| 107 | + |
| 108 | + private void addDefaultCredentialsProviders(DatabricksConfig config) { |
| 109 | + providers.add(new PatCredentialsProvider()); |
| 110 | + providers.add(new BasicCredentialsProvider()); |
| 111 | + providers.add(new OAuthM2MServicePrincipalCredentialsProvider()); |
| 112 | + |
| 113 | + addOIDCCredentialsProviders(config); |
| 114 | + |
| 115 | + providers.add(new AzureGithubOidcCredentialsProvider()); |
| 116 | + providers.add(new AzureServicePrincipalCredentialsProvider()); |
| 117 | + providers.add(new AzureCliCredentialsProvider()); |
| 118 | + providers.add(new ExternalBrowserCredentialsProvider()); |
| 119 | + providers.add(new DatabricksCliCredentialsProvider()); |
| 120 | + providers.add(new NotebookNativeCredentialsProvider()); |
| 121 | + providers.add(new GoogleCredentialsCredentialsProvider()); |
| 122 | + providers.add(new GoogleIdCredentialsProvider()); |
| 123 | + } |
83 | 124 | } |
0 commit comments