diff --git a/acceptance/cmd/auth/login/configure-serverless/out.databrickscfg b/acceptance/cmd/auth/login/configure-serverless/out.databrickscfg index cf1c187658..b095b683c9 100644 --- a/acceptance/cmd/auth/login/configure-serverless/out.databrickscfg +++ b/acceptance/cmd/auth/login/configure-serverless/out.databrickscfg @@ -1,4 +1,5 @@ [DEFAULT] host = [DATABRICKS_URL] serverless_compute_id = auto +workspace_id = [NUMID] auth_type = databricks-cli diff --git a/acceptance/cmd/auth/login/configure-serverless/output.txt b/acceptance/cmd/auth/login/configure-serverless/output.txt index 5a36dab005..bb29a75f69 100644 --- a/acceptance/cmd/auth/login/configure-serverless/output.txt +++ b/acceptance/cmd/auth/login/configure-serverless/output.txt @@ -12,4 +12,5 @@ Profile DEFAULT was successfully saved [DEFAULT] host = [DATABRICKS_URL] serverless_compute_id = auto +workspace_id = [NUMID] auth_type = databricks-cli diff --git a/acceptance/cmd/auth/login/custom-config-file/out.databrickscfg b/acceptance/cmd/auth/login/custom-config-file/out.databrickscfg index e228e10bc6..2097f1a344 100644 --- a/acceptance/cmd/auth/login/custom-config-file/out.databrickscfg +++ b/acceptance/cmd/auth/login/custom-config-file/out.databrickscfg @@ -2,5 +2,6 @@ [DEFAULT] [custom-test] -host = [DATABRICKS_URL] -auth_type = databricks-cli +host = [DATABRICKS_URL] +auth_type = databricks-cli +workspace_id = [NUMID] diff --git a/acceptance/cmd/auth/login/custom-config-file/output.txt b/acceptance/cmd/auth/login/custom-config-file/output.txt index 79e67f48b2..f6657d7c6f 100644 --- a/acceptance/cmd/auth/login/custom-config-file/output.txt +++ b/acceptance/cmd/auth/login/custom-config-file/output.txt @@ -17,5 +17,6 @@ OK: Default .databrickscfg does not exist [DEFAULT] [custom-test] -host = [DATABRICKS_URL] -auth_type = databricks-cli +host = [DATABRICKS_URL] +auth_type = databricks-cli +workspace_id = [NUMID] diff --git a/acceptance/cmd/auth/login/host-arg-overrides-profile/out.databrickscfg b/acceptance/cmd/auth/login/host-arg-overrides-profile/out.databrickscfg index 0b20cb5f03..b72d6351e2 100644 --- a/acceptance/cmd/auth/login/host-arg-overrides-profile/out.databrickscfg +++ b/acceptance/cmd/auth/login/host-arg-overrides-profile/out.databrickscfg @@ -2,5 +2,6 @@ [DEFAULT] [override-test] -host = [DATABRICKS_URL] -auth_type = databricks-cli +host = [DATABRICKS_URL] +workspace_id = [NUMID] +auth_type = databricks-cli diff --git a/acceptance/cmd/auth/login/host-arg-overrides-profile/output.txt b/acceptance/cmd/auth/login/host-arg-overrides-profile/output.txt index b13e876dd0..79371465c4 100644 --- a/acceptance/cmd/auth/login/host-arg-overrides-profile/output.txt +++ b/acceptance/cmd/auth/login/host-arg-overrides-profile/output.txt @@ -12,5 +12,6 @@ Profile override-test was successfully saved [DEFAULT] [override-test] -host = [DATABRICKS_URL] -auth_type = databricks-cli +host = [DATABRICKS_URL] +workspace_id = [NUMID] +auth_type = databricks-cli diff --git a/acceptance/cmd/auth/login/host-from-profile/out.databrickscfg b/acceptance/cmd/auth/login/host-from-profile/out.databrickscfg index 5b725bca90..0c13bde257 100644 --- a/acceptance/cmd/auth/login/host-from-profile/out.databrickscfg +++ b/acceptance/cmd/auth/login/host-from-profile/out.databrickscfg @@ -2,5 +2,6 @@ [DEFAULT] [existing-profile] -host = [DATABRICKS_URL] -auth_type = databricks-cli +host = [DATABRICKS_URL] +workspace_id = [NUMID] +auth_type = databricks-cli diff --git a/acceptance/cmd/auth/login/host-from-profile/output.txt b/acceptance/cmd/auth/login/host-from-profile/output.txt index 5683c925bf..6faae38ae5 100644 --- a/acceptance/cmd/auth/login/host-from-profile/output.txt +++ b/acceptance/cmd/auth/login/host-from-profile/output.txt @@ -12,5 +12,6 @@ Profile existing-profile was successfully saved [DEFAULT] [existing-profile] -host = [DATABRICKS_URL] -auth_type = databricks-cli +host = [DATABRICKS_URL] +workspace_id = [NUMID] +auth_type = databricks-cli diff --git a/acceptance/cmd/auth/login/nominal/out.databrickscfg b/acceptance/cmd/auth/login/nominal/out.databrickscfg index d985d710b4..d94ee2221d 100644 --- a/acceptance/cmd/auth/login/nominal/out.databrickscfg +++ b/acceptance/cmd/auth/login/nominal/out.databrickscfg @@ -2,8 +2,9 @@ [DEFAULT] [test] -host = [DATABRICKS_URL] -auth_type = databricks-cli +host = [DATABRICKS_URL] +workspace_id = [NUMID] +auth_type = databricks-cli [__settings__] default_profile = test diff --git a/acceptance/cmd/auth/login/preserve-fields/output.txt b/acceptance/cmd/auth/login/preserve-fields/output.txt index 625efd0bf2..28c0ed5660 100644 --- a/acceptance/cmd/auth/login/preserve-fields/output.txt +++ b/acceptance/cmd/auth/login/preserve-fields/output.txt @@ -18,4 +18,5 @@ cluster_id = existing-cluster-123 warehouse_id = warehouse-456 azure_environment = USGOVERNMENT custom_key = my-custom-value +workspace_id = [NUMID] auth_type = databricks-cli diff --git a/acceptance/cmd/auth/login/with-scopes/out.databrickscfg b/acceptance/cmd/auth/login/with-scopes/out.databrickscfg index 15911616ac..c8af1832d7 100644 --- a/acceptance/cmd/auth/login/with-scopes/out.databrickscfg +++ b/acceptance/cmd/auth/login/with-scopes/out.databrickscfg @@ -2,9 +2,10 @@ [DEFAULT] [scoped-test] -host = [DATABRICKS_URL] -scopes = jobs,pipelines,clusters -auth_type = databricks-cli +host = [DATABRICKS_URL] +workspace_id = [NUMID] +scopes = jobs,pipelines,clusters +auth_type = databricks-cli [__settings__] default_profile = scoped-test diff --git a/cmd/auth/login.go b/cmd/auth/login.go index 1a1240d84c..c97959b6b1 100644 --- a/cmd/auth/login.go +++ b/cmd/auth/login.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "io" + "net/url" "runtime" + "strconv" "strings" "time" @@ -16,6 +18,7 @@ import ( "github.com/databricks/cli/libs/databrickscfg/profile" "github.com/databricks/cli/libs/env" "github.com/databricks/cli/libs/exec" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/config/experimental/auth/authconv" @@ -100,6 +103,7 @@ depends on the existing profiles you have set in your configuration file var loginTimeout time.Duration var configureCluster bool var configureServerless bool + var skipWorkspace bool var scopes string cmd.Flags().DurationVar(&loginTimeout, "timeout", defaultTimeout, "Timeout for completing login challenge in the browser") @@ -107,6 +111,8 @@ depends on the existing profiles you have set in your configuration file "Prompts to configure cluster") cmd.Flags().BoolVar(&configureServerless, "configure-serverless", false, "Prompts to configure serverless") + cmd.Flags().BoolVar(&skipWorkspace, "skip-workspace", false, + "Skip workspace selection for account-level access") cmd.Flags().StringVar(&scopes, "scopes", "", "Comma-separated list of OAuth scopes to request (defaults to 'all-apis')") @@ -138,13 +144,12 @@ depends on the existing profiles you have set in your configuration file return err } - // Load unified host flags from the profile if not explicitly set via CLI flag + // Load unified host flag from the profile if not explicitly set via CLI flag. + // WorkspaceID is NOT loaded here; it is deferred to setHostAndAccountId() + // so that URL query params (?o=...) can override stale profile values. if !cmd.Flag("experimental-is-unified-host").Changed && existingProfile != nil { authArguments.IsUnifiedHost = existingProfile.IsUnifiedHost } - if !cmd.Flag("workspace-id").Changed && existingProfile != nil { - authArguments.WorkspaceID = existingProfile.WorkspaceID - } err = setHostAndAccountId(ctx, existingProfile, authArguments, args) if err != nil { @@ -193,8 +198,36 @@ depends on the existing profiles you have set in your configuration file } // At this point, an OAuth token has been successfully minted and stored // in the CLI cache. The rest of the command focuses on: - // 1. Configuring cluster and serverless; - // 2. Saving the profile. + // 1. Workspace selection for SPOG hosts (best-effort); + // 2. Configuring cluster and serverless; + // 3. Saving the profile. + + // For SPOG hosts with account_id but no workspace_id, prompt for workspace selection. + // This is skipped for classic accounts.* hosts where account-level access is expected. + // Skip workspace selection if: + // - --skip-workspace flag is set + // - workspace_id is already set (including "none" sentinel from a previous login) + cfg := &config.Config{Host: authArguments.Host} + shouldPromptWorkspace := authArguments.AccountID != "" && + authArguments.WorkspaceID == "" && + cfg.HostType() != config.AccountHost && + !skipWorkspace + + if skipWorkspace && authArguments.WorkspaceID == "" { + authArguments.WorkspaceID = auth.WorkspaceIDNone + } + + if shouldPromptWorkspace { + wsID, wsErr := promptForWorkspaceSelection(ctx, authArguments, persistentAuth) + if wsErr != nil { + log.Warnf(ctx, "Workspace selection failed: %v", wsErr) + } else if wsID == "" { + // User selected "Skip" from the prompt. + authArguments.WorkspaceID = auth.WorkspaceIDNone + } else { + authArguments.WorkspaceID = wsID + } + } var clusterID, serverlessComputeID string @@ -304,6 +337,22 @@ func setHostAndAccountId(ctx context.Context, existingProfile *profile.Profile, authArguments.Host = strings.TrimSuffix(authArguments.Host, "/") + // Extract query parameters from the host URL (?o=workspace_id, ?a=account_id). + // URL params from explicit --host override stale profile values. + extractHostQueryParams(authArguments) + + // Inherit workspace_id from the existing profile AFTER URL param extraction. + // This ensures URL params (?o=...) take precedence over stale profile values, + // while explicit CLI flags (--workspace-id) still win (already set on authArguments). + if authArguments.WorkspaceID == "" && existingProfile != nil && existingProfile.WorkspaceID != "" { + authArguments.WorkspaceID = existingProfile.WorkspaceID + } + + // Call discovery to populate account_id/workspace_id from the host's + // .well-known/databricks-config endpoint. This is best-effort: failures + // are logged as warnings and never block login. + runHostDiscovery(ctx, authArguments) + // Determine the host type and handle account ID / workspace ID accordingly cfg := &config.Config{ Host: authArguments.Host, @@ -314,7 +363,7 @@ func setHostAndAccountId(ctx context.Context, existingProfile *profile.Profile, switch cfg.HostType() { case config.AccountHost: - // Account host - prompt for account ID if not provided + // Account host: prompt for account ID if not provided if authArguments.AccountID == "" { if existingProfile != nil && existingProfile.AccountID != "" { authArguments.AccountID = existingProfile.AccountID @@ -327,7 +376,7 @@ func setHostAndAccountId(ctx context.Context, existingProfile *profile.Profile, } } case config.UnifiedHost: - // Unified host requires an account ID for OAuth URL construction + // Unified host requires an account ID for OAuth URL construction. if authArguments.AccountID == "" { if existingProfile != nil && existingProfile.AccountID != "" { authArguments.AccountID = existingProfile.AccountID @@ -340,16 +389,12 @@ func setHostAndAccountId(ctx context.Context, existingProfile *profile.Profile, } } - // Workspace ID is optional and determines API access level: - // - With workspace ID: workspace-level APIs - // - Without workspace ID: account-level APIs - // If neither is provided via flags, prompt for workspace ID (most common case) - hasWorkspaceID := authArguments.WorkspaceID != "" - if !hasWorkspaceID { + // Workspace ID is optional: with it you get workspace-level APIs, + // without it you get account-level APIs. + if authArguments.WorkspaceID == "" { if existingProfile != nil && existingProfile.WorkspaceID != "" { authArguments.WorkspaceID = existingProfile.WorkspaceID } else { - // Prompt for workspace ID for workspace-level access workspaceId, err := promptForWorkspaceID(ctx) if err != nil { return err @@ -358,7 +403,8 @@ func setHostAndAccountId(ctx context.Context, existingProfile *profile.Profile, } } case config.WorkspaceHost: - // Workspace host - no additional prompts needed + // Regular workspace host: no additional prompts needed. + // If discovery already populated account_id/workspace_id, those are kept. default: return fmt.Errorf("unknown host type: %v", cfg.HostType()) } @@ -366,6 +412,84 @@ func setHostAndAccountId(ctx context.Context, existingProfile *profile.Profile, return nil } +// extractHostQueryParams parses query parameters from the host URL. +// Recognized parameters: o (workspace_id), a (account_id), account_id, workspace_id. +// The host is stripped of all query parameters after extraction. +// Only sets values not already present (explicit flags take precedence). +func extractHostQueryParams(authArguments *auth.AuthArguments) { + u, err := url.Parse(authArguments.Host) + if err != nil || u.RawQuery == "" { + return + } + + q := u.Query() + + // Extract workspace_id from ?o= or ?workspace_id=. + // Workspace IDs are always numeric, so skip non-numeric values to avoid + // confusing downstream errors. + if authArguments.WorkspaceID == "" { + if v := q.Get("o"); v != "" { + if _, err := strconv.ParseInt(v, 10, 64); err == nil { + authArguments.WorkspaceID = v + } + } else if v := q.Get("workspace_id"); v != "" { + if _, err := strconv.ParseInt(v, 10, 64); err == nil { + authArguments.WorkspaceID = v + } + } + } + + // Extract account_id from ?a=, ?account_id= + if authArguments.AccountID == "" { + if v := q.Get("a"); v != "" { + authArguments.AccountID = v + } else if v := q.Get("account_id"); v != "" { + authArguments.AccountID = v + } + } + + // Strip query params from host + u.RawQuery = "" + u.Fragment = "" + u.Path = strings.TrimSuffix(u.Path, "/") + authArguments.Host = u.String() +} + +// runHostDiscovery calls EnsureResolved() with a temporary config to fetch +// .well-known/databricks-config from the host. Populates account_id and +// workspace_id from discovery if not already set. +func runHostDiscovery(ctx context.Context, authArguments *auth.AuthArguments) { + if authArguments.Host == "" { + return + } + + cfg := &config.Config{ + Host: authArguments.Host, + AccountID: authArguments.AccountID, + WorkspaceID: authArguments.WorkspaceID, + HTTPTimeoutSeconds: 5, + // Use only ConfigAttributes (env vars + struct tags), skip config file + // loading to avoid interference from existing profiles. + Loaders: []config.Loader{config.ConfigAttributes}, + } + + err := cfg.EnsureResolved() + if err != nil { + log.Warnf(ctx, "Host metadata discovery failed: %v", err) + return + } + + if authArguments.AccountID == "" && cfg.AccountID != "" { + authArguments.AccountID = cfg.AccountID + } + if authArguments.WorkspaceID == "" && cfg.WorkspaceID != "" { + authArguments.WorkspaceID = cfg.WorkspaceID + } + if authArguments.DiscoveryURL == "" && cfg.DiscoveryURL != "" { + authArguments.DiscoveryURL = cfg.DiscoveryURL + } +} + // getProfileName returns the default profile name for a given host/account ID. // If the account ID is provided, the profile name is "ACCOUNT-". // Otherwise, the profile name is the first part of the host URL. @@ -423,6 +547,65 @@ func oauthLoginClearKeys() []string { return databrickscfg.AuthCredentialKeys() } +// promptForWorkspaceSelection lists workspaces for a SPOG account and lets the +// user pick one. Returns the selected workspace ID or empty string if skipped. +// This is best-effort: errors are returned to the caller for logging, not shown +// to the user. +func promptForWorkspaceSelection(ctx context.Context, authArguments *auth.AuthArguments, persistentAuth *u2m.PersistentAuth) (string, error) { + if !cmdio.IsPromptSupported(ctx) { + return "", nil + } + + a, err := databricks.NewAccountClient(&databricks.Config{ + Host: authArguments.Host, + AccountID: authArguments.AccountID, + Credentials: config.NewTokenSourceStrategy("login-token", authconv.AuthTokenSource(persistentAuth)), + }) + if err != nil { + return "", err + } + + workspaces, err := a.Workspaces.List(ctx) + if err != nil { + return "", err + } + + if len(workspaces) == 0 { + return "", nil + } + + const maxWorkspaces = 50 + if len(workspaces) > maxWorkspaces { + cmdio.LogString(ctx, fmt.Sprintf("Account has %d workspaces. Showing first %d. Use --workspace-id to specify directly.", len(workspaces), maxWorkspaces)) + workspaces = workspaces[:maxWorkspaces] + } + + if len(workspaces) == 1 { + wsID := strconv.FormatInt(workspaces[0].WorkspaceId, 10) + cmdio.LogString(ctx, fmt.Sprintf("Auto-selected workspace %q (%s)", workspaces[0].WorkspaceName, wsID)) + return wsID, nil + } + + items := make([]cmdio.Tuple, 0, len(workspaces)+1) + for _, ws := range workspaces { + items = append(items, cmdio.Tuple{ + Name: ws.WorkspaceName, + Id: strconv.FormatInt(ws.WorkspaceId, 10), + }) + } + // Allow skipping workspace selection for account-level access. + items = append(items, cmdio.Tuple{ + Name: "Skip (account-level access only)", + Id: "", + }) + + selected, err := cmdio.SelectOrdered(ctx, items, "Select a workspace") + if err != nil { + return "", err + } + return selected, nil +} + // getBrowserFunc returns a function that opens the given URL in the browser. // It respects the BROWSER environment variable: // - empty string: uses the default browser diff --git a/cmd/auth/login_test.go b/cmd/auth/login_test.go index bd135bc730..6200db6684 100644 --- a/cmd/auth/login_test.go +++ b/cmd/auth/login_test.go @@ -2,6 +2,9 @@ package auth import ( "context" + "encoding/json" + "net/http" + "net/http/httptest" "testing" "github.com/databricks/cli/libs/auth" @@ -267,3 +270,206 @@ func TestLoadProfileByNameAndClusterID(t *testing.T) { }) } } + +func TestExtractHostQueryParams(t *testing.T) { + tests := []struct { + name string + host string + existingAcctID string + existingWsID string + wantHost string + wantAccountID string + wantWorkspaceID string + }{ + { + name: "extract workspace_id from ?o=", + host: "https://spog.example.com/?o=12345", + wantHost: "https://spog.example.com", + wantWorkspaceID: "12345", + }, + { + name: "extract both account_id and workspace_id", + host: "https://spog.example.com/?o=12345&a=abc", + wantHost: "https://spog.example.com", + wantAccountID: "abc", + wantWorkspaceID: "12345", + }, + { + name: "extract account_id from ?account_id=", + host: "https://spog.example.com/?account_id=abc", + wantHost: "https://spog.example.com", + wantAccountID: "abc", + }, + { + name: "extract workspace_id from ?workspace_id=", + host: "https://spog.example.com/?workspace_id=99999", + wantHost: "https://spog.example.com", + wantWorkspaceID: "99999", + }, + { + name: "no query params leaves host unchanged", + host: "https://spog.example.com", + wantHost: "https://spog.example.com", + }, + { + name: "explicit flags take precedence over query params", + host: "https://spog.example.com/?o=12345&a=abc", + existingAcctID: "explicit-account", + existingWsID: "explicit-ws", + wantHost: "https://spog.example.com", + wantAccountID: "explicit-account", + wantWorkspaceID: "explicit-ws", + }, + { + name: "non-numeric ?o= is skipped", + host: "https://spog.example.com/?o=abc", + wantHost: "https://spog.example.com", + }, + { + name: "non-numeric ?workspace_id= is skipped", + host: "https://spog.example.com/?workspace_id=abc", + wantHost: "https://spog.example.com", + wantWorkspaceID: "", + }, + { + name: "invalid URL is left unchanged", + host: "not a valid url ://???", + wantHost: "not a valid url ://???", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + args := &auth.AuthArguments{ + Host: tt.host, + AccountID: tt.existingAcctID, + WorkspaceID: tt.existingWsID, + } + extractHostQueryParams(args) + assert.Equal(t, tt.wantHost, args.Host) + assert.Equal(t, tt.wantAccountID, args.AccountID) + assert.Equal(t, tt.wantWorkspaceID, args.WorkspaceID) + }) + } +} + +func TestRunHostDiscovery_NoHost(t *testing.T) { + ctx := t.Context() + args := &auth.AuthArguments{} + runHostDiscovery(ctx, args) + assert.Equal(t, "", args.AccountID) + assert.Equal(t, "", args.WorkspaceID) +} + +func TestRunHostDiscovery_ExplicitFieldsNotOverridden(t *testing.T) { + ctx := t.Context() + args := &auth.AuthArguments{ + Host: "https://nonexistent.example.com", + AccountID: "explicit-account", + WorkspaceID: "explicit-ws", + } + runHostDiscovery(ctx, args) + // Explicit fields should not be overridden even if discovery would return values + assert.Equal(t, "explicit-account", args.AccountID) + assert.Equal(t, "explicit-ws", args.WorkspaceID) +} + +// newDiscoveryServer creates a test HTTP server that responds to +// .well-known/databricks-config with the given metadata. +func newDiscoveryServer(t *testing.T, metadata map[string]any) *httptest.Server { + t.Helper() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/databricks-config" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(metadata); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + t.Cleanup(server.Close) + return server +} + +func TestRunHostDiscovery_SPOGHost(t *testing.T) { + server := newDiscoveryServer(t, map[string]any{ + "account_id": "discovered-account", + "workspace_id": "discovered-ws", + "oidc_endpoint": "https://spog.example.com/oidc/accounts/discovered-account", + }) + + ctx := t.Context() + args := &auth.AuthArguments{Host: server.URL} + runHostDiscovery(ctx, args) + + assert.Equal(t, "discovered-account", args.AccountID) + assert.Equal(t, "discovered-ws", args.WorkspaceID) +} + +func TestRunHostDiscovery_ClassicWorkspaceDoesNotSetAccountID(t *testing.T) { + // Classic workspace discovery returns workspace-scoped OIDC (no account in path). + server := newDiscoveryServer(t, map[string]any{ + "workspace_id": "12345", + "oidc_endpoint": "https://ws.example.com/oidc", + }) + + ctx := t.Context() + args := &auth.AuthArguments{Host: server.URL} + runHostDiscovery(ctx, args) + + // Only workspace_id is set; account_id stays empty since discovery didn't return it. + assert.Equal(t, "", args.AccountID) + assert.Equal(t, "12345", args.WorkspaceID) +} + +func TestExtractHostQueryParams_OverridesProfileWorkspaceID(t *testing.T) { + // Simulates the fix: profile loads workspace_id="old-ws", then the user + // provides --host https://spog.example.com?o=99999. After Fix 1, profile + // inheritance is deferred, so authArguments.WorkspaceID is empty when + // extractHostQueryParams runs, and URL param wins. + args := &auth.AuthArguments{ + Host: "https://spog.example.com/?o=99999", + // WorkspaceID is empty because profile inheritance was deferred. + } + extractHostQueryParams(args) + assert.Equal(t, "https://spog.example.com", args.Host) + assert.Equal(t, "99999", args.WorkspaceID) +} + +func TestSetHostAndAccountId_WorkspaceIDNoneSentinelInherited(t *testing.T) { + t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") + ctx, _ := cmdio.SetupTest(t.Context(), cmdio.TestOptions{}) + + skipProfile := loadTestProfile(t, ctx, "spog-skip-workspace") + + // When loading from a profile with workspace_id=none, the sentinel should + // be inherited and the workspace prompt should not fire. + args := auth.AuthArguments{ + Host: "https://spog.example.com", + AccountID: "spog-account", + } + err := setHostAndAccountId(ctx, skipProfile, &args, []string{}) + assert.NoError(t, err) + assert.Equal(t, auth.WorkspaceIDNone, args.WorkspaceID) +} + +func TestSetHostAndAccountId_URLParamsOverrideProfile(t *testing.T) { + t.Setenv("DATABRICKS_CONFIG_FILE", "./testdata/.databrickscfg") + ctx, _ := cmdio.SetupTest(t.Context(), cmdio.TestOptions{}) + + unifiedWorkspaceProfile := loadTestProfile(t, ctx, "unified-workspace") + + // The profile has workspace_id=123456789, but the URL has ?o=99999. + // URL params should win over profile values. + args := auth.AuthArguments{ + Host: "https://unified.databricks.com?o=99999", + AccountID: "test-unified-account", + IsUnifiedHost: true, + } + err := setHostAndAccountId(ctx, unifiedWorkspaceProfile, &args, []string{}) + assert.NoError(t, err) + assert.Equal(t, "https://unified.databricks.com", args.Host) + assert.Equal(t, "99999", args.WorkspaceID) +} diff --git a/cmd/auth/testdata/.databrickscfg b/cmd/auth/testdata/.databrickscfg index ca1a063076..fe836a53b4 100644 --- a/cmd/auth/testdata/.databrickscfg +++ b/cmd/auth/testdata/.databrickscfg @@ -26,3 +26,8 @@ experimental_is_unified_host = true host = https://unified.databricks.com account_id = test-unified-account experimental_is_unified_host = true + +[spog-skip-workspace] +host = https://spog.example.com +account_id = spog-account +workspace_id = none diff --git a/cmd/auth/token.go b/cmd/auth/token.go index 5f695ce6bc..0df233168b 100644 --- a/cmd/auth/token.go +++ b/cmd/auth/token.go @@ -37,7 +37,9 @@ const ( ) // applyUnifiedHostFlags copies unified host fields from the profile to the -// auth arguments when they are not already set. +// auth arguments when they are not already set. WorkspaceID is NOT copied +// here; it is deferred to setHostAndAccountId() so that URL query params +// (?o=...) can override stale profile values. func applyUnifiedHostFlags(p *profile.Profile, args *auth.AuthArguments) { if p == nil { return @@ -45,9 +47,6 @@ func applyUnifiedHostFlags(p *profile.Profile, args *auth.AuthArguments) { if !args.IsUnifiedHost && p.IsUnifiedHost { args.IsUnifiedHost = p.IsUnifiedHost } - if args.WorkspaceID == "" && p.WorkspaceID != "" { - args.WorkspaceID = p.WorkspaceID - } } func newTokenCommand(authArguments *auth.AuthArguments) *cobra.Command { @@ -176,19 +175,15 @@ func loadToken(ctx context.Context, args loadTokenArgs) (*oauth2.Token, error) { // primary key. Once older SDKs have migrated to profile-based keys, // dualWrite and the host key can be removed entirely. if args.profileName == "" && args.authArguments.Host != "" { - cfg := &config.Config{ - Host: args.authArguments.Host, - AccountID: args.authArguments.AccountID, - Experimental_IsUnifiedHost: args.authArguments.IsUnifiedHost, - } - // Canonicalize first so HostType() can correctly identify account hosts - // even when the host string lacks a scheme (e.g. "accounts.cloud.databricks.com"). - cfg.CanonicalHostName() + // Match profiles by host and available identifiers. For SPOG workspace + // profiles (host + account_id + workspace_id), use all three to + // disambiguate between workspaces sharing the same host and account. var matchFn profile.ProfileMatchFunction - switch cfg.HostType() { - case config.AccountHost, config.UnifiedHost: + if args.authArguments.AccountID != "" && args.authArguments.WorkspaceID != "" { + matchFn = profile.WithHostAccountIDAndWorkspaceID(args.authArguments.Host, args.authArguments.AccountID, args.authArguments.WorkspaceID) + } else if args.authArguments.AccountID != "" { matchFn = profile.WithHostAndAccountID(args.authArguments.Host, args.authArguments.AccountID) - default: + } else { matchFn = profile.WithHost(args.authArguments.Host) } diff --git a/libs/auth/arguments.go b/libs/auth/arguments.go index 8e00d89507..ee0dc6f21e 100644 --- a/libs/auth/arguments.go +++ b/libs/auth/arguments.go @@ -1,12 +1,16 @@ package auth import ( - "fmt" + "strings" "github.com/databricks/databricks-sdk-go/config" "github.com/databricks/databricks-sdk-go/credentials/u2m" ) +// WorkspaceIDNone is a sentinel value persisted to .databrickscfg when the +// user explicitly skips workspace selection for SPOG account-level access. +const WorkspaceIDNone = "none" + // AuthArguments is a struct that contains the common arguments passed to // `databricks auth` commands. type AuthArguments struct { @@ -18,28 +22,63 @@ type AuthArguments struct { // Profile is the optional profile name. When set, the OAuth token cache // key is the profile name instead of the host-based key. Profile string + + // DiscoveryURL is cached from host metadata discovery to avoid duplicate + // network calls when both runHostDiscovery and ToOAuthArgument need it. + DiscoveryURL string } // ToOAuthArgument converts the AuthArguments to an OAuthArgument from the Go SDK. +// It calls EnsureResolved() to run host metadata discovery and routes based on +// the resolved DiscoveryURL rather than the Experimental_IsUnifiedHost flag. func (a AuthArguments) ToOAuthArgument() (u2m.OAuthArgument, error) { + // Strip the "none" sentinel so it is never passed to the SDK. + workspaceID := a.WorkspaceID + if workspaceID == WorkspaceIDNone { + workspaceID = "" + } + cfg := &config.Config{ Host: a.Host, AccountID: a.AccountID, - WorkspaceID: a.WorkspaceID, + WorkspaceID: workspaceID, Experimental_IsUnifiedHost: a.IsUnifiedHost, + HTTPTimeoutSeconds: 5, + // Skip config file loading. We only want host metadata resolution + // based on the explicit fields provided. + Loaders: []config.Loader{config.ConfigAttributes}, } + + discoveryURL := a.DiscoveryURL + if discoveryURL == "" { + // No cached discovery, resolve fresh. + if err := cfg.EnsureResolved(); err == nil { + discoveryURL = cfg.DiscoveryURL + } + } + host := cfg.CanonicalHostName() - switch cfg.HostType() { - case config.AccountHost: + // Classic accounts.* hosts always use account OAuth, even if discovery + // returned data. This preserves backward compatibility. + if (&config.Config{Host: host}).HostType() == config.AccountHost { return u2m.NewProfileAccountOAuthArgument(host, cfg.AccountID, a.Profile) - case config.WorkspaceHost: - return u2m.NewProfileWorkspaceOAuthArgument(host, a.Profile) - case config.UnifiedHost: - // For unified hosts, always use the unified OAuth argument with account ID. - // The workspace ID is stored in the config for API routing, not OAuth. + } + + // Route based on discovery data: a non-accounts host with an account-scoped + // OIDC endpoint is a SPOG/unified host. We check a.AccountID (the caller- + // provided value) rather than cfg.AccountID to avoid env var contamination + // (e.g. DATABRICKS_ACCOUNT_ID set in the environment). We also require the + // DiscoveryURL to contain "/oidc/accounts/" to distinguish SPOG hosts from + // classic workspace hosts that may also return discovery metadata. + if a.AccountID != "" && discoveryURL != "" && strings.Contains(discoveryURL, "/oidc/accounts/") { + return u2m.NewProfileUnifiedOAuthArgument(host, cfg.AccountID, a.Profile) + } + + // Legacy backward compat: existing profiles with IsUnifiedHost flag. + if a.IsUnifiedHost && a.AccountID != "" { return u2m.NewProfileUnifiedOAuthArgument(host, cfg.AccountID, a.Profile) - default: - return nil, fmt.Errorf("unknown host type: %v", cfg.HostType()) } + + return u2m.NewProfileWorkspaceOAuthArgument(host, a.Profile) } diff --git a/libs/auth/arguments_test.go b/libs/auth/arguments_test.go index 7b41b9dbfd..415e87c0dd 100644 --- a/libs/auth/arguments_test.go +++ b/libs/auth/arguments_test.go @@ -1,10 +1,14 @@ package auth import ( + "encoding/json" + "net/http" + "net/http/httptest" "testing" "github.com/databricks/databricks-sdk-go/credentials/u2m" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestToOAuthArgument(t *testing.T) { @@ -116,6 +120,18 @@ func TestToOAuthArgument(t *testing.T) { wantHost: "https://unified.cloud.databricks.com", wantCacheKey: "my-unified-profile", }, + { + name: "workspace_id none sentinel is stripped", + args: AuthArguments{ + Host: "https://unified.cloud.databricks.com", + AccountID: "123456789", + WorkspaceID: "none", + IsUnifiedHost: true, + Profile: "my-profile", + }, + wantHost: "https://unified.cloud.databricks.com", + wantCacheKey: "my-profile", + }, } for _, tt := range tests { @@ -145,3 +161,97 @@ func TestToOAuthArgument(t *testing.T) { }) } } + +func TestToOAuthArgument_SPOGHostRoutesToUnified(t *testing.T) { + // A SPOG host returns an account-scoped OIDC endpoint from discovery. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/databricks-config" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "account_id": "spog-account", + "workspace_id": "spog-ws", + "oidc_endpoint": r.Host + "/oidc/accounts/spog-account", + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + args := AuthArguments{ + Host: server.URL, + AccountID: "spog-account", + } + got, err := args.ToOAuthArgument() + require.NoError(t, err) + + // Should route to unified OAuth. + _, ok := got.(u2m.UnifiedOAuthArgument) + assert.True(t, ok, "expected UnifiedOAuthArgument for SPOG host, got %T", got) +} + +func TestToOAuthArgument_ClassicWorkspaceNotMisrouted(t *testing.T) { + // A classic workspace host returns workspace-scoped OIDC (no /accounts/ in path). + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/databricks-config" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "workspace_id": "12345", + "oidc_endpoint": r.Host + "/oidc", + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + // Even with AccountID set (from env or caller), a classic workspace host + // should NOT be routed to unified OAuth. + args := AuthArguments{ + Host: server.URL, + AccountID: "some-account", + } + got, err := args.ToOAuthArgument() + require.NoError(t, err) + + // Should route to workspace OAuth, not unified. + _, ok := got.(u2m.WorkspaceOAuthArgument) + assert.True(t, ok, "expected WorkspaceOAuthArgument for classic workspace, got %T", got) +} + +func TestToOAuthArgument_NoAccountIDSkipsUnifiedRouting(t *testing.T) { + // Even if discovery returns an account-scoped OIDC URL, without an explicit + // AccountID from the caller, unified routing should NOT be triggered. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/databricks-config" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "account_id": "discovered-account", + "oidc_endpoint": r.Host + "/oidc/accounts/discovered-account", + }); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + args := AuthArguments{ + Host: server.URL, + // No AccountID set by caller. + } + got, err := args.ToOAuthArgument() + require.NoError(t, err) + + // Should route to workspace OAuth because caller didn't provide AccountID. + _, ok := got.(u2m.WorkspaceOAuthArgument) + assert.True(t, ok, "expected WorkspaceOAuthArgument when no caller AccountID, got %T", got) +} diff --git a/libs/auth/credentials.go b/libs/auth/credentials.go index 6adf0ab9c0..7ab6eb2a85 100644 --- a/libs/auth/credentials.go +++ b/libs/auth/credentials.go @@ -129,5 +129,6 @@ func authArgumentsFromConfig(cfg *config.Config) AuthArguments { WorkspaceID: cfg.WorkspaceID, IsUnifiedHost: cfg.Experimental_IsUnifiedHost, Profile: cfg.Profile, + DiscoveryURL: cfg.DiscoveryURL, } } diff --git a/libs/auth/error.go b/libs/auth/error.go index 674083313e..ea26651e04 100644 --- a/libs/auth/error.go +++ b/libs/auth/error.go @@ -134,9 +134,10 @@ func writeReauthSteps(ctx context.Context, cfg *config.Config, b *strings.Builde return } loginCmd := BuildLoginCommand(ctx, "", oauthArg) - // For unified hosts, BuildLoginCommand (via OAuthArgument) doesn't carry - // workspace-id. Append it so the command is actionable. - if cfg.Experimental_IsUnifiedHost && cfg.WorkspaceID != "" { + // For SPOG/unified hosts, the OAuthArgument doesn't carry workspace-id. + // Append it so the re-auth command is actionable. Only for hosts with + // both account_id and workspace_id (i.e. SPOG/unified hosts). + if cfg.WorkspaceID != "" && cfg.AccountID != "" { loginCmd += " --workspace-id " + cfg.WorkspaceID } fmt.Fprintf(b, "\n - Re-authenticate: %s", loginCmd) @@ -178,6 +179,9 @@ func BuildLoginCommand(ctx context.Context, profile string, arg u2m.OAuthArgumen } else { switch arg := arg.(type) { case u2m.UnifiedOAuthArgument: + // The --experimental-is-unified-host flag is redundant now that + // discovery handles routing, but kept for backward compatibility + // until the flag is fully removed. cmd = append(cmd, "--host", arg.GetHost(), "--account-id", arg.GetAccountId(), "--experimental-is-unified-host") case u2m.AccountOAuthArgument: cmd = append(cmd, "--host", arg.GetAccountHost(), "--account-id", arg.GetAccountId()) diff --git a/libs/databrickscfg/profile/profiler.go b/libs/databrickscfg/profile/profiler.go index 8eff2675b9..af99794799 100644 --- a/libs/databrickscfg/profile/profiler.go +++ b/libs/databrickscfg/profile/profiler.go @@ -3,23 +3,25 @@ package profile import ( "context" + "github.com/databricks/cli/libs/auth" "github.com/databricks/databricks-sdk-go/config" ) type ProfileMatchFunction func(Profile) bool func MatchWorkspaceProfiles(p Profile) bool { - // Match workspace profiles: regular workspace profiles (no account ID) - // or unified hosts with workspace ID - return (p.AccountID == "" && !p.IsUnifiedHost) || - (p.IsUnifiedHost && p.WorkspaceID != "") + // Workspace profile: has workspace_id (covers both classic and SPOG profiles), + // or is a regular workspace host (no account_id and not a legacy unified-host profile). + // workspace_id = "none" is a sentinel for "skip workspace", so it does NOT count. + return (p.WorkspaceID != "" && p.WorkspaceID != auth.WorkspaceIDNone) || (p.AccountID == "" && !p.IsUnifiedHost) } func MatchAccountProfiles(p Profile) bool { - // Match account profiles: regular account profiles (with account ID) - // or unified hosts with account ID but no workspace ID - return (p.Host != "" && p.AccountID != "" && !p.IsUnifiedHost) || - (p.IsUnifiedHost && p.AccountID != "" && p.WorkspaceID == "") + // Account profile: has host and account_id but no workspace_id. + // workspace_id = "none" is a sentinel for account-level access, treated as empty. + // This covers classic accounts.* profiles, legacy unified-host account profiles, + // and new SPOG account profiles. + return p.Host != "" && p.AccountID != "" && (p.WorkspaceID == "" || p.WorkspaceID == auth.WorkspaceIDNone) } func MatchAllProfiles(p Profile) bool { @@ -62,6 +64,17 @@ func WithHostAndAccountID(host, accountID string) ProfileMatchFunction { } } +// WithHostAccountIDAndWorkspaceID returns a ProfileMatchFunction that matches +// profiles by canonical host, account ID, and workspace ID. This is used for +// SPOG workspace profiles where multiple workspaces share the same host and +// account ID. +func WithHostAccountIDAndWorkspaceID(host, accountID, workspaceID string) ProfileMatchFunction { + target := canonicalizeHost(host) + return func(p Profile) bool { + return p.Host != "" && canonicalizeHost(p.Host) == target && p.AccountID == accountID && p.WorkspaceID == workspaceID + } +} + // canonicalizeHost normalizes a host using the SDK's canonical host logic. func canonicalizeHost(host string) string { return (&config.Config{Host: host}).CanonicalHostName() diff --git a/libs/databrickscfg/profile/profiler_test.go b/libs/databrickscfg/profile/profiler_test.go index aa13e76a46..66db4dcbb5 100644 --- a/libs/databrickscfg/profile/profiler_test.go +++ b/libs/databrickscfg/profile/profiler_test.go @@ -123,3 +123,169 @@ func TestWithHostAndAccountID(t *testing.T) { }) } } + +func TestWithHostAccountIDAndWorkspaceID(t *testing.T) { + cases := []struct { + name string + inputHost string + inputAccountID string + inputWorkspaceID string + profileHost string + profileAccountID string + profileWorkspaceID string + want bool + }{ + { + name: "all three match", + inputHost: "https://spog.example.com", + inputAccountID: "acc-1", + inputWorkspaceID: "ws-1", + profileHost: "https://spog.example.com", + profileAccountID: "acc-1", + profileWorkspaceID: "ws-1", + want: true, + }, + { + name: "different workspace_id", + inputHost: "https://spog.example.com", + inputAccountID: "acc-1", + inputWorkspaceID: "ws-1", + profileHost: "https://spog.example.com", + profileAccountID: "acc-1", + profileWorkspaceID: "ws-2", + want: false, + }, + { + name: "different account_id", + inputHost: "https://spog.example.com", + inputAccountID: "acc-1", + inputWorkspaceID: "ws-1", + profileHost: "https://spog.example.com", + profileAccountID: "acc-2", + profileWorkspaceID: "ws-1", + want: false, + }, + { + name: "different host", + inputHost: "https://other.example.com", + inputAccountID: "acc-1", + inputWorkspaceID: "ws-1", + profileHost: "https://spog.example.com", + profileAccountID: "acc-1", + profileWorkspaceID: "ws-1", + want: false, + }, + { + name: "empty host on profile", + inputHost: "https://spog.example.com", + inputAccountID: "acc-1", + inputWorkspaceID: "ws-1", + profileHost: "", + profileAccountID: "acc-1", + profileWorkspaceID: "ws-1", + want: false, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + p := Profile{Host: c.profileHost, AccountID: c.profileAccountID, WorkspaceID: c.profileWorkspaceID} + fn := WithHostAccountIDAndWorkspaceID(c.inputHost, c.inputAccountID, c.inputWorkspaceID) + assert.Equal(t, c.want, fn(p)) + }) + } +} + +func TestMatchWorkspaceProfiles(t *testing.T) { + tests := []struct { + name string + profile Profile + want bool + }{ + { + name: "regular workspace (no account_id)", + profile: Profile{Host: "https://ws.cloud.databricks.com"}, + want: true, + }, + { + name: "SPOG workspace (has workspace_id)", + profile: Profile{Host: "https://spog.example.com", AccountID: "acc-1", WorkspaceID: "ws-1"}, + want: true, + }, + { + name: "legacy unified workspace (has workspace_id and IsUnifiedHost)", + profile: Profile{Host: "https://unified.example.com", AccountID: "acc-1", WorkspaceID: "ws-1", IsUnifiedHost: true}, + want: true, + }, + { + name: "regular account profile (has account_id, no workspace_id)", + profile: Profile{Host: "https://accounts.cloud.databricks.com", AccountID: "acc-1"}, + want: false, + }, + { + name: "legacy unified account (IsUnifiedHost, no workspace_id)", + profile: Profile{Host: "https://unified.example.com", AccountID: "acc-1", IsUnifiedHost: true}, + want: false, + }, + { + name: "workspace_id none sentinel is not a workspace profile", + profile: Profile{Host: "https://spog.example.com", AccountID: "acc-1", WorkspaceID: "none"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, MatchWorkspaceProfiles(tt.profile)) + }) + } +} + +func TestMatchAccountProfiles(t *testing.T) { + tests := []struct { + name string + profile Profile + want bool + }{ + { + name: "regular account profile", + profile: Profile{Host: "https://accounts.cloud.databricks.com", AccountID: "acc-1"}, + want: true, + }, + { + name: "SPOG account profile (account_id, no workspace_id)", + profile: Profile{Host: "https://spog.example.com", AccountID: "acc-1"}, + want: true, + }, + { + name: "legacy unified account profile", + profile: Profile{Host: "https://unified.example.com", AccountID: "acc-1", IsUnifiedHost: true}, + want: true, + }, + { + name: "workspace_id none sentinel matches as account profile", + profile: Profile{Host: "https://spog.example.com", AccountID: "acc-1", WorkspaceID: "none"}, + want: true, + }, + { + name: "SPOG workspace profile (has workspace_id)", + profile: Profile{Host: "https://spog.example.com", AccountID: "acc-1", WorkspaceID: "ws-1"}, + want: false, + }, + { + name: "regular workspace (no account_id)", + profile: Profile{Host: "https://ws.cloud.databricks.com"}, + want: false, + }, + { + name: "no host", + profile: Profile{AccountID: "acc-1"}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, MatchAccountProfiles(tt.profile)) + }) + } +}