From 5f4bc44c5e03da7ddcbab6be8961d6531c29f30c Mon Sep 17 00:00:00 2001 From: anthonyivn2 Date: Wed, 11 Mar 2026 09:55:11 +0800 Subject: [PATCH] Add `--refresh-before` flag to `databricks auth token` Users who use `databricks auth token` as an API key helper (e.g., for Claude Code) get expired tokens because the oauth2 library only refreshes within ~10 seconds of expiry. The new `--refresh-before` flag (e.g., `--refresh-before 5m`) refreshes the token if it expires within the given window. Depends on: https://github.com/databricks/databricks-sdk-go/pull/1532 Resolves #4564 --- cmd/auth/token.go | 11 ++++++++ cmd/auth/token_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 2 ++ go.sum | 2 -- 4 files changed, 77 insertions(+), 2 deletions(-) diff --git a/cmd/auth/token.go b/cmd/auth/token.go index ca8582bd02..a091e108ed 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -65,6 +65,10 @@ using a client ID and secret is not supported.`, cmd.Flags().DurationVar(&tokenTimeout, "timeout", defaultTimeout, "Timeout for acquiring a token.") + var refreshBefore time.Duration + cmd.Flags().DurationVar(&refreshBefore, "refresh-before", 0, + "Refresh the token if it expires within this duration (e.g., 5m, 30s).") + cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() profileName := "" @@ -78,6 +82,7 @@ using a client ID and secret is not supported.`, profileName: profileName, args: args, tokenTimeout: tokenTimeout, + refreshBefore: refreshBefore, profiler: profile.DefaultProfiler, persistentAuthOpts: nil, }) @@ -108,6 +113,9 @@ type loadTokenArgs struct { // tokenTimeout is the timeout for retrieving (and potentially refreshing) an OAuth token. tokenTimeout time.Duration + // refreshBefore triggers a token refresh if the token expires within this duration. + refreshBefore time.Duration + // profiler is the profiler to use for reading the host and account ID from the .databrickscfg file. profiler profile.Profiler @@ -242,6 +250,9 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { return nil, err } allArgs := append(args.persistentAuthOpts, u2m.WithOAuthArgument(oauthArgument)) + if args.refreshBefore > 0 { + allArgs = append(allArgs, u2m.WithExpiryDelta(args.refreshBefore)) + } persistentAuth, err := u2m.NewPersistentAuth(ctx, allArgs...) if err != nil { helpMsg := helpfulError(ctx, args.profileName, oauthArgument) diff --git a/cmd/auth/token_test.go b/cmd/auth/token_test.go index 3ec69e0e33..58ac823df5 100644 --- a/cmd/auth/token_test.go +++ b/cmd/auth/token_test.go @@ -130,6 +130,11 @@ func TestToken_loadToken(t *testing.T) { Name: "legacy-ws", Host: "https://legacy-ws.cloud.databricks.com", }, + { + Name: "valid-token", + Host: "https://accounts.cloud.databricks.com", + AccountID: "valid-token", + }, { Name: "m2m-profile", Host: "https://m2m.cloud.databricks.com", @@ -642,6 +647,65 @@ func TestToken_loadToken(t *testing.T) { }, validateToken: validateToken, }, + { + name: "refreshBefore skips refresh when token has enough time", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "valid-token", + args: []string{}, + tokenTimeout: 1 * time.Hour, + refreshBefore: 5 * time.Minute, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{ + "valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)}, + }}), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + }, + }, + validateToken: func(resp *oauth2.Token) { + assert.Equal(t, "still-valid", resp.AccessToken) + }, + }, + { + name: "refreshBefore zero preserves default behavior", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "valid-token", + args: []string{}, + tokenTimeout: 1 * time.Hour, + refreshBefore: 0, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{ + "valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)}, + }}), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + }, + }, + validateToken: func(resp *oauth2.Token) { + assert.Equal(t, "still-valid", resp.AccessToken) + }, + }, + { + name: "refreshBefore forces refresh when token expires within window", + args: loadTokenArgs{ + authArguments: &auth.AuthArguments{}, + profileName: "valid-token", + args: []string{}, + tokenTimeout: 1 * time.Hour, + refreshBefore: 2 * time.Hour, + profiler: profiler, + persistentAuthOpts: []u2m.PersistentAuthOption{ + u2m.WithTokenCache(&inMemoryTokenCache{Tokens: map[string]*oauth2.Token{ + "valid-token": {AccessToken: "still-valid", RefreshToken: "valid-token", Expiry: time.Now().Add(1 * time.Hour)}, + }}), + u2m.WithOAuthEndpointSupplier(&MockApiClient{}), + u2m.WithHttpClient(&http.Client{Transport: fixtures.SliceTransport{refreshSuccessTokenResponse}}), + }, + }, + validateToken: validateToken, + }, { name: "host flag with profile env var disambiguates multi-profile", setupCtx: func(ctx context.Context) context.Context { diff --git a/go.mod b/go.mod index 7f633c0c72..058d6ae6cd 100644 --- a/go.mod +++ b/go.mod @@ -110,3 +110,5 @@ require ( google.golang.org/grpc v1.78.0 // indirect google.golang.org/protobuf v1.36.11 // indirect ) + +replace github.com/databricks/databricks-sdk-go => /Users/anthony.ivan/projects/databricks-sdk-go diff --git a/go.sum b/go.sum index 3cb5d3228f..7c8c0c61e2 100644 --- a/go.sum +++ b/go.sum @@ -75,8 +75,6 @@ github.com/creack/pty v1.1.24 h1:bJrF4RRfyJnbTJqzRLHzcGaZK1NeM5kTC9jGgovnR1s= github.com/creack/pty v1.1.24/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s= github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= -github.com/databricks/databricks-sdk-go v0.119.0 h1:Fot5T4bBGxfuFHII0xLPXuzkBmALWiJeUBeuXQX2Pcw= -github.com/databricks/databricks-sdk-go v0.119.0/go.mod h1:hWoHnHbNLjPKiTm5K/7bcIv3J3Pkgo5x9pPzh8K3RVE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=