Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## [Unreleased]

### Added
-
- Support for token cache in OAuth U2M Flow using the configuration parameters: `EnableTokenCache` and `TokenCachePassPhrase`.

### Updated
-
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
<httpclient.version>4.5.14</httpclient.version>
<commons-configuration.version>2.10.1</commons-configuration.version>
<commons-io.version>2.14.0</commons-io.version>
<databricks-sdk.version>0.44.0</databricks-sdk.version>
<databricks-sdk.version>0.46.0</databricks-sdk.version>
<maven-surefire-plugin.version>3.1.2</maven-surefire-plugin.version>
<sql-logic-test.version>0.3</sql-logic-test.version>
<lz4-compression.version>1.8.0</lz4-compression.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,16 @@ public int getSocketTimeout() {
return Integer.parseInt(getParameter(DatabricksJdbcUrlParams.SOCKET_TIMEOUT));
}

/** Returns the passphrase configured for encrypting/decrypting the OAuth token cache file. */
@Override
public String getTokenCachePassPhrase() {
  final String passPhrase = getParameter(DatabricksJdbcUrlParams.TOKEN_CACHE_PASS_PHRASE);
  return passPhrase;
}

/**
 * Returns whether OAuth token caching is enabled.
 *
 * <p>Uses a null-safe comparison so that a missing parameter disables caching instead of
 * throwing a {@link NullPointerException}.
 *
 * @return true when the EnableTokenCache URL parameter is "1", false otherwise
 */
@Override
public boolean isTokenCacheEnabled() {
  // Constant-first equals is null-safe; getParameter(...).equals("1") would NPE
  // if the parameter resolves to null.
  return "1".equals(getParameter(DatabricksJdbcUrlParams.ENABLE_TOKEN_CACHE));
}

/** Returns true when the given string is either null or has zero length. */
private static boolean nullOrEmptyString(String s) {
  if (s == null) {
    return true;
  }
  return s.isEmpty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,10 @@ public interface IDatabricksConnectionContext {
* @return true if the system property trust store should be used, false otherwise
*/
boolean useSystemTrustStore();

/**
 * Returns the passphrase used for encrypting/decrypting token cache.
 *
 * @return the configured passphrase; may be null when none was supplied in the connection
 *     parameters
 */
String getTokenCachePassPhrase();

/**
 * Returns whether token caching is enabled for OAuth authentication.
 *
 * @return true if OAuth tokens should be cached on disk, false otherwise
 */
boolean isTokenCacheEnabled();
}
179 changes: 179 additions & 0 deletions src/main/java/com/databricks/jdbc/auth/EncryptedFileTokenCache.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package com.databricks.jdbc.auth;

import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.oauth.Token;
import com.databricks.sdk.core.oauth.TokenCache;
import com.databricks.sdk.core.utils.SerDeUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.SecureRandom;
import java.security.spec.KeySpec;
import java.util.Base64;
import java.util.Objects;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.SecretKeyFactory;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.PBEKeySpec;
import javax.crypto.spec.SecretKeySpec;

/**
 * A TokenCache implementation that stores OAuth tokens in encrypted files.
 *
 * <p>Tokens are serialized to JSON, encrypted with AES/CBC/PKCS5Padding using a key derived from
 * a user-supplied passphrase (PBKDF2-HMAC-SHA256), Base64 encoded, and written to a file readable
 * only by the owner. A random IV is generated per save and prepended to the ciphertext.
 */
public class EncryptedFileTokenCache implements TokenCache {
  private static final JdbcLogger LOGGER =
      JdbcLoggerFactory.getLogger(EncryptedFileTokenCache.class);

  // Encryption constants
  private static final String ALGORITHM = "AES";
  private static final String TRANSFORMATION = "AES/CBC/PKCS5Padding";
  private static final String SECRET_KEY_ALGORITHM = "PBKDF2WithHmacSHA256";
  // Fixed application salt; the passphrase is the only secret. UTF-8 is specified explicitly so
  // key derivation never depends on the platform default charset.
  private static final byte[] SALT = "DatabricksJdbcTokenCache".getBytes(StandardCharsets.UTF_8);
  private static final int ITERATION_COUNT = 65536;
  private static final int KEY_LENGTH = 256;
  private static final int IV_SIZE = 16; // 128 bits

  private final Path cacheFile;
  private final ObjectMapper mapper;
  private final String passphrase;

  /**
   * Constructs a new EncryptedFileTokenCache instance.
   *
   * @param cacheFilePath The path where the token cache will be stored
   * @param passphrase The passphrase used for encryption
   */
  public EncryptedFileTokenCache(Path cacheFilePath, String passphrase) {
    Objects.requireNonNull(cacheFilePath, "cacheFilePath must be defined");
    Objects.requireNonNull(passphrase, "passphrase must be defined for encrypted token cache");

    this.cacheFile = cacheFilePath;
    this.mapper = SerDeUtils.createMapper();
    this.passphrase = passphrase;
  }

  /**
   * Encrypts and persists the given token to the cache file.
   *
   * @param token the token to persist
   * @throws DatabricksException if serialization, encryption, or the file write fails
   */
  @Override
  public void save(Token token) throws DatabricksException {
    try {
      Files.createDirectories(cacheFile.getParent());

      // Serialize token to JSON
      String json = mapper.writeValueAsString(token);
      byte[] dataToWrite = json.getBytes(StandardCharsets.UTF_8);

      // Encrypt data
      dataToWrite = encrypt(dataToWrite);

      Files.write(cacheFile, dataToWrite);
      // Set file permissions to be readable/writable only by the owner (equivalent to 0600).
      // setReadable/setWritable report unsupported operations via their return value (e.g. on
      // some Windows file systems); log rather than fail, since the content is encrypted anyway.
      File file = cacheFile.toFile();
      boolean restricted =
          file.setReadable(false, false)
              & file.setReadable(true, true)
              & file.setWritable(false, false)
              & file.setWritable(true, true);
      if (!restricted) {
        LOGGER.debug("Could not fully restrict permissions on token cache file: %s", cacheFile);
      }

      LOGGER.debug("Successfully saved encrypted token to cache: %s", cacheFile);
    } catch (Exception e) {
      throw new DatabricksException("Failed to save token cache: " + e.getMessage(), e);
    }
  }

  /**
   * Loads and decrypts the cached token, if present.
   *
   * @return the cached token, or null when the file is missing, unreadable, or cannot be
   *     decrypted (e.g. wrong passphrase) so that a fresh token can be obtained
   */
  @Override
  public Token load() {
    try {
      if (!Files.exists(cacheFile)) {
        LOGGER.debug("No token cache file found at: %s", cacheFile);
        return null;
      }

      byte[] fileContent = Files.readAllBytes(cacheFile);

      // Decrypt data; a failure here typically means a wrong passphrase or corrupt file.
      byte[] decodedContent;
      try {
        decodedContent = decrypt(fileContent);
      } catch (Exception e) {
        LOGGER.debug("Failed to decrypt token cache: %s", e.getMessage());
        return null;
      }

      // Deserialize token from JSON
      String json = new String(decodedContent, StandardCharsets.UTF_8);
      Token token = mapper.readValue(json, Token.class);
      LOGGER.debug("Successfully loaded encrypted token from cache: %s", cacheFile);
      return token;
    } catch (Exception e) {
      // If there's any issue loading the token, return null
      // to allow a fresh token to be obtained
      LOGGER.debug("Failed to load token from cache: %s", e.getMessage());
      return null;
    }
  }

  /**
   * Generates a secret key from the passphrase using PBKDF2 with HMAC-SHA256.
   *
   * @return A SecretKey generated from the passphrase
   * @throws Exception If an error occurs generating the key
   */
  private SecretKey generateSecretKey() throws Exception {
    SecretKeyFactory factory = SecretKeyFactory.getInstance(SECRET_KEY_ALGORITHM);
    KeySpec spec = new PBEKeySpec(passphrase.toCharArray(), SALT, ITERATION_COUNT, KEY_LENGTH);
    return new SecretKeySpec(factory.generateSecret(spec).getEncoded(), ALGORITHM);
  }

  /**
   * Encrypts the given data using AES/CBC/PKCS5Padding encryption with a key derived from the
   * passphrase. The IV is generated randomly and prepended to the encrypted data.
   *
   * @param data The data to encrypt
   * @return The Base64-encoded encrypted data with IV prepended
   * @throws Exception If an error occurs during encryption
   */
  private byte[] encrypt(byte[] data) throws Exception {
    Cipher cipher = Cipher.getInstance(TRANSFORMATION);

    // Generate a fresh random IV per save; CBC must never reuse an IV under the same key.
    SecureRandom random = new SecureRandom();
    byte[] iv = new byte[IV_SIZE];
    random.nextBytes(iv);
    IvParameterSpec ivSpec = new IvParameterSpec(iv);

    cipher.init(Cipher.ENCRYPT_MODE, generateSecretKey(), ivSpec);
    byte[] encryptedData = cipher.doFinal(data);

    // Combine IV and encrypted data
    byte[] combined = new byte[iv.length + encryptedData.length];
    System.arraycopy(iv, 0, combined, 0, iv.length);
    System.arraycopy(encryptedData, 0, combined, iv.length, encryptedData.length);

    return Base64.getEncoder().encode(combined);
  }

  /**
   * Decrypts the given encrypted data using AES/CBC/PKCS5Padding decryption with a key derived
   * from the passphrase. The IV is extracted from the beginning of the encrypted data.
   *
   * @param encryptedData The encrypted data with IV prepended, Base64 encoded
   * @return The decrypted data
   * @throws Exception If an error occurs during decryption
   */
  private byte[] decrypt(byte[] encryptedData) throws Exception {
    byte[] decodedData = Base64.getDecoder().decode(encryptedData);

    // A valid payload always carries a full IV; fail with a clear message on truncated/corrupt
    // content instead of a NegativeArraySizeException from the copy below.
    if (decodedData.length < IV_SIZE) {
      throw new IllegalArgumentException("Encrypted token cache content is truncated");
    }

    // Extract IV
    byte[] iv = new byte[IV_SIZE];
    byte[] actualData = new byte[decodedData.length - IV_SIZE];
    System.arraycopy(decodedData, 0, iv, 0, IV_SIZE);
    System.arraycopy(decodedData, IV_SIZE, actualData, 0, actualData.length);

    Cipher cipher = Cipher.getInstance(TRANSFORMATION);
    IvParameterSpec ivSpec = new IvParameterSpec(iv);
    cipher.init(Cipher.DECRYPT_MODE, generateSecretKey(), ivSpec);

    return cipher.doFinal(actualData);
  }
}
25 changes: 25 additions & 0 deletions src/main/java/com/databricks/jdbc/auth/NoOpTokenCache.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.databricks.jdbc.auth;

import com.databricks.jdbc.log.JdbcLogger;
import com.databricks.jdbc.log.JdbcLoggerFactory;
import com.databricks.sdk.core.oauth.Token;
import com.databricks.sdk.core.oauth.TokenCache;

/**
 * {@link TokenCache} implementation that never persists anything. Installed when the user has
 * explicitly turned OAuth token caching off, so every authentication acquires a fresh token.
 */
public class NoOpTokenCache implements TokenCache {
  private static final JdbcLogger LOGGER = JdbcLoggerFactory.getLogger(NoOpTokenCache.class);

  /** Always reports a cache miss so callers fetch a new token. */
  @Override
  public Token load() {
    LOGGER.debug("Token caching is disabled, skipping load operation");
    return null;
  }

  /** Discards the supplied token without persisting it anywhere. */
  @Override
  public void save(Token token) {
    LOGGER.debug("Token caching is disabled, skipping save operation");
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ public enum DatabricksJdbcUrlParams {
"DefaultStringColumnLength",
"Maximum number of characters that can be contained in STRING columns",
"255"),
SOCKET_TIMEOUT("socketTimeout", "Socket timeout in seconds", "900");
SOCKET_TIMEOUT("socketTimeout", "Socket timeout in seconds", "900"),
TOKEN_CACHE_PASS_PHRASE("TokenCachePassPhrase", "Pass phrase to use for OAuth U2M Token Cache"),
ENABLE_TOKEN_CACHE("EnableTokenCache", "Enable caching OAuth tokens", "1");

private final String paramName;
private final String defaultValue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import static com.databricks.jdbc.common.util.DatabricksAuthUtil.initializeConfigWithToken;

import com.databricks.jdbc.api.internal.IDatabricksConnectionContext;
import com.databricks.jdbc.auth.AzureMSICredentialProvider;
import com.databricks.jdbc.auth.OAuthRefreshCredentialsProvider;
import com.databricks.jdbc.auth.PrivateKeyClientCredentialProvider;
import com.databricks.jdbc.auth.*;
import com.databricks.jdbc.common.AuthMech;
import com.databricks.jdbc.common.DatabricksJdbcConstants;
import com.databricks.jdbc.common.util.DriverUtil;
Expand All @@ -20,9 +18,13 @@
import com.databricks.sdk.core.DatabricksException;
import com.databricks.sdk.core.ProxyConfig;
import com.databricks.sdk.core.commons.CommonsHttpClient;
import com.databricks.sdk.core.oauth.ExternalBrowserCredentialsProvider;
import com.databricks.sdk.core.oauth.TokenCache;
import com.databricks.sdk.core.utils.Cloud;
import java.io.IOException;
import java.net.ServerSocket;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -52,6 +54,53 @@ public ClientConfigurator(IDatabricksConnectionContext connectionContext) {
this.databricksConfig.resolve();
}

/**
 * Returns the path of the token cache file for the given host, client ID, and scopes. Each
 * distinct combination of parameters maps to its own cache file via a hashed identifier.
 *
 * @param host The host URL
 * @param clientId The OAuth client ID
 * @param scopes The OAuth scopes
 * @return The path for the token cache file
 */
public static Path getTokenCachePath(String host, String clientId, List<String> scopes) {
  // Caches live under ~/.config/databricks-jdbc/oauth, one file per identity.
  Path cacheDir =
      Paths.get(System.getProperty("user.home"), ".config", "databricks-jdbc", "oauth");
  return cacheDir.resolve("token-cache-" + createUniqueIdentifier(host, clientId, scopes));
}

/**
 * Creates a unique identifier string from the given parameters.
 *
 * <p>Uses a truncated SHA-256 digest rather than {@link String#hashCode()}: a 32-bit hash makes
 * accidental collisions between different host/clientId/scope combinations plausible, and a
 * collision would let one identity load another identity's cached token. Note: changing the
 * identifier scheme moves existing cache files, costing at most a one-time re-authentication.
 *
 * @param host The host URL
 * @param clientId The OAuth client ID
 * @param scopes The OAuth scopes
 * @return A unique identifier string (lowercase hexadecimal)
 */
private static String createUniqueIdentifier(String host, String clientId, List<String> scopes) {
  // Normalize inputs to handle null values
  host = (host != null) ? host : EMPTY_STRING;
  clientId = (clientId != null) ? clientId : EMPTY_STRING;
  scopes = (scopes != null) ? scopes : List.of();

  // Combine all parameters
  String combined = host + URL_DELIMITER + clientId + URL_DELIMITER + String.join(COMMA, scopes);

  try {
    // Fully qualified names keep the import block untouched; SHA-256 is mandated on every JRE.
    java.security.MessageDigest digest = java.security.MessageDigest.getInstance("SHA-256");
    byte[] hash = digest.digest(combined.getBytes(java.nio.charset.StandardCharsets.UTF_8));

    // 16 hex chars (64 bits) keeps the filename short while making collisions negligible.
    StringBuilder hex = new StringBuilder(16);
    for (int i = 0; i < 8; i++) {
      hex.append(String.format("%02x", hash[i]));
    }
    return hex.toString();
  } catch (java.security.NoSuchAlgorithmException e) {
    // Cannot happen on a compliant JRE; fail loudly rather than silently mis-caching tokens.
    throw new IllegalStateException("SHA-256 digest not available", e);
  }
}

/**
* Setup the SSL configuration in the httpClientBuilder.
*
Expand Down Expand Up @@ -136,10 +185,13 @@ public void setupU2MConfig() throws DatabricksParsingException {
int redirectPort = findAvailablePort(connectionContext.getOAuth2RedirectUrlPorts());
String redirectUrl = String.format("http://localhost:%d", redirectPort);

String host = connectionContext.getHostForOAuth();
String clientId = connectionContext.getClientId();

databricksConfig
.setAuthType(DatabricksJdbcConstants.U2M_AUTH_TYPE)
.setHost(connectionContext.getHostForOAuth())
.setClientId(connectionContext.getClientId())
.setHost(host)
.setClientId(clientId)
.setClientSecret(connectionContext.getClientSecret())
.setOAuthRedirectUrl(redirectUrl);

Expand All @@ -148,6 +200,21 @@ public void setupU2MConfig() throws DatabricksParsingException {
if (!databricksConfig.isAzure()) {
databricksConfig.setScopes(connectionContext.getOAuthScopesForU2M());
}

TokenCache tokenCache;
if (connectionContext.isTokenCacheEnabled()) {
if (connectionContext.getTokenCachePassPhrase() == null) {
LOGGER.error("No token cache passphrase configured");
throw new DatabricksException("No token cache passphrase configured");
}
Path tokenCachePath = getTokenCachePath(host, clientId, databricksConfig.getScopes());
tokenCache =
new EncryptedFileTokenCache(tokenCachePath, connectionContext.getTokenCachePassPhrase());
} else {
tokenCache = new NoOpTokenCache();
}
CredentialsProvider provider = new ExternalBrowserCredentialsProvider(tokenCache);
databricksConfig.setCredentialsProvider(provider).setAuthType(provider.authType());
}

/**
Expand Down Expand Up @@ -229,14 +296,13 @@ public void resetAccessTokenInConfig(String newAccessToken) {

/** Setup the OAuth U2M refresh token authentication settings in the databricks config. */
public void setupU2MRefreshConfig() throws DatabricksParsingException {
CredentialsProvider provider =
new OAuthRefreshCredentialsProvider(connectionContext, databricksConfig);
databricksConfig
.setHost(connectionContext.getHostForOAuth())
.setAuthType(provider.authType()) // oauth-refresh
.setCredentialsProvider(provider)
.setClientId(connectionContext.getClientId())
.setClientSecret(connectionContext.getClientSecret());
CredentialsProvider provider =
new OAuthRefreshCredentialsProvider(connectionContext, databricksConfig);
databricksConfig.setAuthType(provider.authType()).setCredentialsProvider(provider);
}

/** Setup the OAuth M2M authentication settings in the databricks config. */
Expand Down
Loading
Loading