Skip to content

Commit 8279fa1

Browse files
committed
Add token federation and caching layer
This PR adds the federation and caching layer for token providers. This is the second of three PRs implementing token federation support. New components: - CachedTokenProvider: Wraps providers with automatic caching - Configurable refresh threshold (default 5 minutes before expiry) - Thread-safe handling of concurrent requests - clearCache() method for manual invalidation - FederationProvider: Wraps providers with RFC 8693 token exchange - Automatically exchanges external IdP tokens for Databricks tokens - Compares JWT issuer with Databricks host to determine if exchange needed - Graceful fallback to original token on exchange failure - Supports optional clientId for M2M/service principal federation - utils.ts: JWT decoding and host comparison utilities - decodeJWT: Decode JWT payload without verification - getJWTIssuer: Extract issuer from JWT - isSameHost: Compare hostnames ignoring ports New connection options: - enableTokenFederation: Enable automatic token exchange - federationClientId: Client ID for M2M federation
1 parent ba4d0b4 commit 8279fa1

File tree

9 files changed

+856
-3
lines changed

9 files changed

+856
-3
lines changed

lib/DBSQLClient.ts

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ import {
2323
TokenProviderAuthenticator,
2424
StaticTokenProvider,
2525
ExternalTokenProvider,
26+
CachedTokenProvider,
27+
FederationProvider,
28+
ITokenProvider,
2629
} from './connection/auth/tokenProvider';
2730
import IDBSQLLogger, { LogLevel } from './contracts/IDBSQLLogger';
2831
import DBSQLLogger from './DBSQLLogger';
@@ -149,15 +152,47 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient, I
149152
case 'custom':
150153
return options.provider;
151154
case 'token-provider':
152-
return new TokenProviderAuthenticator(options.tokenProvider, this);
155+
return new TokenProviderAuthenticator(
156+
this.wrapTokenProvider(options.tokenProvider, options.host, options.enableTokenFederation, options.federationClientId),
157+
this,
158+
);
153159
case 'external-token':
154-
return new TokenProviderAuthenticator(new ExternalTokenProvider(options.getToken), this);
160+
return new TokenProviderAuthenticator(
161+
this.wrapTokenProvider(new ExternalTokenProvider(options.getToken), options.host, options.enableTokenFederation, options.federationClientId),
162+
this,
163+
);
155164
case 'static-token':
156-
return new TokenProviderAuthenticator(StaticTokenProvider.fromJWT(options.staticToken), this);
165+
return new TokenProviderAuthenticator(
166+
this.wrapTokenProvider(StaticTokenProvider.fromJWT(options.staticToken), options.host, options.enableTokenFederation, options.federationClientId),
167+
this,
168+
);
157169
// no default
158170
}
159171
}
160172

173+
/**
174+
* Wraps a token provider with caching and optional federation.
175+
* Caching is always enabled by default. Federation is opt-in.
176+
*/
177+
private wrapTokenProvider(
178+
provider: ITokenProvider,
179+
host: string,
180+
enableFederation?: boolean,
181+
federationClientId?: string,
182+
): ITokenProvider {
183+
// Always wrap with caching first
184+
let wrapped: ITokenProvider = new CachedTokenProvider(provider);
185+
186+
// Optionally wrap with federation
187+
if (enableFederation) {
188+
wrapped = new FederationProvider(wrapped, host, {
189+
clientId: federationClientId,
190+
});
191+
}
192+
193+
return wrapped;
194+
}
195+
161196
private createConnectionProvider(options: ConnectionOptions): IConnectionProvider {
162197
return new HttpConnection(this.getConnectionOptions(options), this);
163198
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import ITokenProvider from './ITokenProvider';
2+
import Token from './Token';
3+
4+
/**
5+
* Default refresh threshold in milliseconds (5 minutes).
6+
* Tokens will be refreshed when they are within this threshold of expiring.
7+
*/
8+
const DEFAULT_REFRESH_THRESHOLD_MS = 5 * 60 * 1000;
9+
10+
/**
11+
* A token provider that wraps another provider with automatic caching.
12+
* Tokens are cached and reused until they are close to expiring.
13+
*/
14+
export default class CachedTokenProvider implements ITokenProvider {
15+
private readonly baseProvider: ITokenProvider;
16+
17+
private readonly refreshThresholdMs: number;
18+
19+
private cache: Token | null = null;
20+
21+
private refreshPromise: Promise<Token> | null = null;
22+
23+
/**
24+
* Creates a new CachedTokenProvider.
25+
* @param baseProvider - The underlying token provider to cache
26+
* @param options - Optional configuration
27+
* @param options.refreshThresholdMs - Refresh tokens this many ms before expiry (default: 5 minutes)
28+
*/
29+
constructor(
30+
baseProvider: ITokenProvider,
31+
options?: {
32+
refreshThresholdMs?: number;
33+
},
34+
) {
35+
this.baseProvider = baseProvider;
36+
this.refreshThresholdMs = options?.refreshThresholdMs ?? DEFAULT_REFRESH_THRESHOLD_MS;
37+
}
38+
39+
async getToken(): Promise<Token> {
40+
// Return cached token if it's still valid
41+
if (this.cache && !this.shouldRefresh(this.cache)) {
42+
return this.cache;
43+
}
44+
45+
// If already refreshing, wait for that to complete
46+
if (this.refreshPromise) {
47+
return this.refreshPromise;
48+
}
49+
50+
// Start refresh
51+
this.refreshPromise = this.refreshToken();
52+
53+
try {
54+
const token = await this.refreshPromise;
55+
return token;
56+
} finally {
57+
this.refreshPromise = null;
58+
}
59+
}
60+
61+
getName(): string {
62+
return `cached[${this.baseProvider.getName()}]`;
63+
}
64+
65+
/**
66+
* Clears the cached token, forcing a refresh on the next getToken() call.
67+
*/
68+
clearCache(): void {
69+
this.cache = null;
70+
}
71+
72+
/**
73+
* Determines if the token should be refreshed.
74+
* @param token - The token to check
75+
* @returns true if the token should be refreshed
76+
*/
77+
private shouldRefresh(token: Token): boolean {
78+
// If no expiration is known, don't refresh proactively
79+
if (!token.expiresAt) {
80+
return false;
81+
}
82+
83+
const now = Date.now();
84+
const expiresAtMs = token.expiresAt.getTime();
85+
const refreshAtMs = expiresAtMs - this.refreshThresholdMs;
86+
87+
return now >= refreshAtMs;
88+
}
89+
90+
/**
91+
* Fetches a new token from the base provider and caches it.
92+
*/
93+
private async refreshToken(): Promise<Token> {
94+
const token = await this.baseProvider.getToken();
95+
this.cache = token;
96+
return token;
97+
}
98+
}
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
import fetch from 'node-fetch';
2+
import ITokenProvider from './ITokenProvider';
3+
import Token from './Token';
4+
import { decodeJWT, getJWTIssuer, isSameHost } from './utils';
5+
6+
/**
7+
* Token exchange endpoint path for Databricks OIDC.
8+
*/
9+
const TOKEN_EXCHANGE_ENDPOINT = '/oidc/v1/token';
10+
11+
/**
12+
* Grant type for RFC 8693 token exchange.
13+
*/
14+
const TOKEN_EXCHANGE_GRANT_TYPE = 'urn:ietf:params:oauth:grant-type:token-exchange';
15+
16+
/**
17+
* Subject token type for JWT tokens.
18+
*/
19+
const SUBJECT_TOKEN_TYPE = 'urn:ietf:params:oauth:token-type:jwt';
20+
21+
/**
22+
* Default scope for SQL operations.
23+
*/
24+
const DEFAULT_SCOPE = 'sql';
25+
26+
/**
27+
* Timeout for token exchange requests in milliseconds.
28+
*/
29+
const REQUEST_TIMEOUT_MS = 30000;
30+
31+
/**
32+
* A token provider that wraps another provider with automatic token federation.
33+
* When the base provider returns a token from a different issuer, this provider
34+
* exchanges it for a Databricks-compatible token using RFC 8693.
35+
*/
36+
export default class FederationProvider implements ITokenProvider {
37+
private readonly baseProvider: ITokenProvider;
38+
39+
private readonly databricksHost: string;
40+
41+
private readonly clientId?: string;
42+
43+
private readonly returnOriginalTokenOnFailure: boolean;
44+
45+
/**
46+
* Creates a new FederationProvider.
47+
* @param baseProvider - The underlying token provider
48+
* @param databricksHost - The Databricks workspace host URL
49+
* @param options - Optional configuration
50+
* @param options.clientId - Client ID for M2M/service principal federation
51+
* @param options.returnOriginalTokenOnFailure - Return original token if exchange fails (default: true)
52+
*/
53+
constructor(
54+
baseProvider: ITokenProvider,
55+
databricksHost: string,
56+
options?: {
57+
clientId?: string;
58+
returnOriginalTokenOnFailure?: boolean;
59+
},
60+
) {
61+
this.baseProvider = baseProvider;
62+
this.databricksHost = databricksHost;
63+
this.clientId = options?.clientId;
64+
this.returnOriginalTokenOnFailure = options?.returnOriginalTokenOnFailure ?? true;
65+
}
66+
67+
async getToken(): Promise<Token> {
68+
const token = await this.baseProvider.getToken();
69+
70+
// Check if token needs exchange
71+
if (!this.needsTokenExchange(token)) {
72+
return token;
73+
}
74+
75+
// Attempt token exchange
76+
try {
77+
return await this.exchangeToken(token);
78+
} catch (error) {
79+
if (this.returnOriginalTokenOnFailure) {
80+
// Fall back to original token
81+
return token;
82+
}
83+
throw error;
84+
}
85+
}
86+
87+
getName(): string {
88+
return `federated[${this.baseProvider.getName()}]`;
89+
}
90+
91+
/**
92+
* Determines if the token needs to be exchanged.
93+
* @param token - The token to check
94+
* @returns true if the token should be exchanged
95+
*/
96+
private needsTokenExchange(token: Token): boolean {
97+
const issuer = getJWTIssuer(token.accessToken);
98+
99+
// If we can't extract the issuer, don't exchange (might not be a JWT)
100+
if (!issuer) {
101+
return false;
102+
}
103+
104+
// If the issuer is the same as Databricks host, no exchange needed
105+
if (isSameHost(issuer, this.databricksHost)) {
106+
return false;
107+
}
108+
109+
return true;
110+
}
111+
112+
/**
113+
* Exchanges the token for a Databricks-compatible token using RFC 8693.
114+
* @param token - The token to exchange
115+
* @returns The exchanged token
116+
*/
117+
private async exchangeToken(token: Token): Promise<Token> {
118+
const url = this.buildExchangeUrl();
119+
120+
const params = new URLSearchParams({
121+
grant_type: TOKEN_EXCHANGE_GRANT_TYPE,
122+
subject_token_type: SUBJECT_TOKEN_TYPE,
123+
subject_token: token.accessToken,
124+
scope: DEFAULT_SCOPE,
125+
});
126+
127+
if (this.clientId) {
128+
params.append('client_id', this.clientId);
129+
}
130+
131+
const controller = new AbortController();
132+
const timeoutId = setTimeout(() => controller.abort(), REQUEST_TIMEOUT_MS);
133+
134+
try {
135+
const response = await fetch(url, {
136+
method: 'POST',
137+
headers: {
138+
'Content-Type': 'application/x-www-form-urlencoded',
139+
},
140+
body: params.toString(),
141+
signal: controller.signal,
142+
});
143+
144+
if (!response.ok) {
145+
const errorText = await response.text();
146+
throw new Error(`Token exchange failed: ${response.status} ${response.statusText} - ${errorText}`);
147+
}
148+
149+
const data = (await response.json()) as {
150+
access_token?: string;
151+
token_type?: string;
152+
expires_in?: number;
153+
};
154+
155+
if (!data.access_token) {
156+
throw new Error('Token exchange response missing access_token');
157+
}
158+
159+
// Calculate expiration from expires_in
160+
let expiresAt: Date | undefined;
161+
if (typeof data.expires_in === 'number') {
162+
expiresAt = new Date(Date.now() + data.expires_in * 1000);
163+
}
164+
165+
return new Token(data.access_token, {
166+
tokenType: data.token_type ?? 'Bearer',
167+
expiresAt,
168+
});
169+
} finally {
170+
clearTimeout(timeoutId);
171+
}
172+
}
173+
174+
/**
175+
* Builds the token exchange URL.
176+
*/
177+
private buildExchangeUrl(): string {
178+
let host = this.databricksHost;
179+
180+
// Ensure host has a protocol
181+
if (!host.includes('://')) {
182+
host = `https://${host}`;
183+
}
184+
185+
// Remove trailing slash
186+
if (host.endsWith('/')) {
187+
host = host.slice(0, -1);
188+
}
189+
190+
return `${host}${TOKEN_EXCHANGE_ENDPOINT}`;
191+
}
192+
}

lib/connection/auth/tokenProvider/index.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,6 @@ export { default as Token } from './Token';
33
export { default as StaticTokenProvider } from './StaticTokenProvider';
44
export { default as ExternalTokenProvider, TokenCallback } from './ExternalTokenProvider';
55
export { default as TokenProviderAuthenticator } from './TokenProviderAuthenticator';
6+
export { default as CachedTokenProvider } from './CachedTokenProvider';
7+
export { default as FederationProvider } from './FederationProvider';
8+
export { decodeJWT, getJWTIssuer, isSameHost } from './utils';

0 commit comments

Comments
 (0)