-
Notifications
You must be signed in to change notification settings - Fork 148
Auto-generate session names for serverless SSH connect #4701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: ssh-connect-elapsed-time
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,7 @@ import ( | |
|
|
||
| "github.com/databricks/cli/experimental/ssh/internal/keys" | ||
| "github.com/databricks/cli/experimental/ssh/internal/proxy" | ||
| "github.com/databricks/cli/experimental/ssh/internal/sessions" | ||
| "github.com/databricks/cli/experimental/ssh/internal/sshconfig" | ||
| "github.com/databricks/cli/experimental/ssh/internal/vscode" | ||
| sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" | ||
|
|
@@ -99,11 +100,11 @@ type ClientOptions struct { | |
| } | ||
|
|
||
| func (o *ClientOptions) Validate() error { | ||
| if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" { | ||
| return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)") | ||
| if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" && o.Accelerator == "" { | ||
| return errors.New("please provide --cluster or --accelerator flag") | ||
| } | ||
| if o.Accelerator != "" && o.ConnectionName == "" { | ||
| return errors.New("--accelerator flag can only be used with serverless compute (--name flag)") | ||
| if o.Accelerator != "" && o.ClusterID != "" { | ||
| return errors.New("--accelerator flag can only be used with serverless compute, not with --cluster") | ||
| } | ||
| // TODO: Remove when we add support for serverless CPU | ||
| if o.ConnectionName != "" && o.Accelerator == "" { | ||
|
|
@@ -122,7 +123,7 @@ func (o *ClientOptions) Validate() error { | |
| } | ||
|
|
||
| func (o *ClientOptions) IsServerlessMode() bool { | ||
| return o.ClusterID == "" && o.ConnectionName != "" | ||
| return o.ClusterID == "" && (o.ConnectionName != "" || o.Accelerator != "") | ||
| } | ||
|
|
||
| // SessionIdentifier returns the unique identifier for the session. | ||
|
|
@@ -202,9 +203,17 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt | |
| cancel() | ||
| }() | ||
|
|
||
| // For serverless without explicit --name: auto-generate or reconnect to existing session. | ||
| if opts.IsServerlessMode() && opts.ConnectionName == "" && !opts.ProxyMode { | ||
| err := resolveServerlessSession(ctx, client, &opts) | ||
| if err != nil { | ||
| return err | ||
| } | ||
| } | ||
|
|
||
| sessionID := opts.SessionIdentifier() | ||
| if sessionID == "" { | ||
| return errors.New("either --cluster or --name must be provided") | ||
| return errors.New("either --cluster or --accelerator must be provided") | ||
| } | ||
|
|
||
| if !opts.ProxyMode { | ||
|
|
@@ -327,6 +336,20 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt | |
| cmdio.LogString(ctx, "Connected!") | ||
| } | ||
|
|
||
| // Persist the session for future reconnects. | ||
| if opts.IsServerlessMode() && !opts.ProxyMode { | ||
| err = sessions.Add(ctx, sessions.Session{ | ||
| Name: opts.ConnectionName, | ||
| Accelerator: opts.Accelerator, | ||
| WorkspaceHost: client.Config.Host, | ||
| CreatedAt: time.Now(), | ||
| ClusterID: clusterID, | ||
| }) | ||
| if err != nil { | ||
| log.Warnf(ctx, "Failed to save session state: %v", err) | ||
| } | ||
| } | ||
|
|
||
| if opts.ProxyMode { | ||
| return runSSHProxy(ctx, client, serverPort, clusterID, opts) | ||
| } else if opts.IDE != "" { | ||
|
|
@@ -379,7 +402,12 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k | |
| return fmt.Errorf("failed to generate ProxyCommand: %w", err) | ||
| } | ||
|
|
||
| hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) | ||
| var hostConfig string | ||
| if opts.IsServerlessMode() { | ||
| hostConfig = sshconfig.GenerateServerlessHostConfig(hostName, userName, keyPath, proxyCommand) | ||
| } else { | ||
| hostConfig = sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) | ||
| } | ||
|
|
||
| _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) | ||
| if err != nil { | ||
|
|
@@ -547,15 +575,22 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server | |
|
|
||
| hostName := opts.SessionIdentifier() | ||
|
|
||
| hostKeyChecking := "StrictHostKeyChecking=accept-new" | ||
| if opts.IsServerlessMode() { | ||
| hostKeyChecking = "StrictHostKeyChecking=no" | ||
| } | ||
|
|
||
| sshArgs := []string{ | ||
| "-l", userName, | ||
| "-i", privateKeyPath, | ||
| "-o", "IdentitiesOnly=yes", | ||
| "-o", "StrictHostKeyChecking=accept-new", | ||
| "-o", hostKeyChecking, | ||
| "-o", "ConnectTimeout=360", | ||
| "-o", "ProxyCommand=" + proxyCommand, | ||
| } | ||
| if opts.UserKnownHostsFile != "" { | ||
| if opts.IsServerlessMode() { | ||
| sshArgs = append(sshArgs, "-o", "UserKnownHostsFile=/dev/null") | ||
| } else if opts.UserKnownHostsFile != "" { | ||
| sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile) | ||
| } | ||
| sshArgs = append(sshArgs, hostName) | ||
|
|
@@ -703,3 +738,97 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC | |
|
|
||
| return userName, serverPort, effectiveClusterID, nil | ||
| } | ||
|
|
||
| // resolveServerlessSession handles auto-generation and reconnection for serverless sessions. | ||
| // It checks local state for existing sessions matching the workspace and accelerator, | ||
| // probes them to see if they're still alive, and prompts the user to reconnect or create new. | ||
| func resolveServerlessSession(ctx context.Context, client *databricks.WorkspaceClient, opts *ClientOptions) error { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit, but this can be a method on the ClientOptions struct; it might be easier to understand that we are mutating the options here then. |
||
| version := build.GetInfo().Version | ||
|
|
||
| matching, err := sessions.FindMatching(ctx, client.Config.Host, opts.Accelerator) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I feel like the majority of this logic can be moved to the sessions package (up until line 788). getServerMetadata can be passed as an argument. Then it will be easier to test. Same for cleanupStaleSession. Or will there be circular dependencies if we do that? (since that function has a lot of them) |
||
| if err != nil { | ||
| log.Warnf(ctx, "Failed to load session state: %v", err) | ||
| } | ||
|
|
||
| // Probe sessions to find alive ones (limit to 5 most recent to avoid latency). | ||
| const maxProbe = 5 | ||
| if len(matching) > maxProbe { | ||
| matching = matching[len(matching)-maxProbe:] | ||
| } | ||
|
|
||
| var alive []sessions.Session | ||
| for _, s := range matching { | ||
| _, _, _, probeErr := getServerMetadata(ctx, client, s.Name, s.ClusterID, version, opts.Liteswap) | ||
| if probeErr == nil { | ||
| alive = append(alive, s) | ||
| } else { | ||
| cleanupStaleSession(ctx, client, s, version) | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Agent Swarm Review] [Critical] Any probe error is treated as proof that the session is stale.
Both reviewers flagged this. Isaac confirmed Critical in cross-review due to irreversible blast radius. Suggestion: Only run destructive cleanup on definitive stale signals (e.g., 404/not-found). For transient errors, keep the session and surface a warning. |
||
| } | ||
|
|
||
| if len(alive) > 0 && cmdio.IsPromptSupported(ctx) { | ||
| choices := make([]string, 0, len(alive)+1) | ||
| for _, s := range alive { | ||
| choices = append(choices, fmt.Sprintf("Reconnect to %s (started %s)", s.Name, s.CreatedAt.Format(time.RFC822))) | ||
| } | ||
| choices = append(choices, "Create new session") | ||
|
|
||
| choice, choiceErr := cmdio.AskSelect(ctx, "Found existing sessions:", choices) | ||
| if choiceErr != nil { | ||
| return fmt.Errorf("failed to prompt user: %w", choiceErr) | ||
| } | ||
|
|
||
| for i, s := range alive { | ||
| if choice == choices[i] { | ||
| opts.ConnectionName = s.Name | ||
| cmdio.LogString(ctx, "Reconnecting to session: "+s.Name) | ||
| return nil | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // No alive session selected — generate a new name. | ||
| opts.ConnectionName = sessions.GenerateSessionName(opts.Accelerator) | ||
| cmdio.LogString(ctx, "Creating new session: "+opts.ConnectionName) | ||
| return nil | ||
| } | ||
|
|
||
| // cleanupStaleSession removes all local and remote artifacts for a stale session. | ||
| func cleanupStaleSession(ctx context.Context, client *databricks.WorkspaceClient, s sessions.Session, version string) { | ||
| // Remove local SSH keys. | ||
| keyPath, err := keys.GetLocalSSHKeyPath(ctx, s.Name, "") | ||
| if err == nil { | ||
| os.RemoveAll(filepath.Dir(keyPath)) | ||
| } | ||
|
|
||
| // Remove SSH config entry. | ||
| if err := sshconfig.RemoveHostConfig(ctx, s.Name); err != nil { | ||
| log.Debugf(ctx, "Failed to remove SSH config for %s: %v", s.Name, err) | ||
| } | ||
|
|
||
| // Delete secret scope (best-effort). | ||
| me, err := client.CurrentUser.Me(ctx) | ||
| if err == nil { | ||
| scopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, s.Name) | ||
| deleteErr := client.Secrets.DeleteScope(ctx, workspace.DeleteScope{Scope: scopeName}) | ||
| if deleteErr != nil { | ||
| log.Debugf(ctx, "Failed to delete secret scope %s: %v", scopeName, deleteErr) | ||
| } | ||
| } | ||
|
|
||
| // Remove workspace content directory (best-effort). | ||
| contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, s.Name) | ||
| if err == nil { | ||
| deleteErr := client.Workspace.Delete(ctx, workspace.Delete{Path: contentDir, Recursive: true}) | ||
| if deleteErr != nil { | ||
| log.Debugf(ctx, "Failed to delete workspace content for %s: %v", s.Name, deleteErr) | ||
| } | ||
| } | ||
|
|
||
| // Remove from local state. | ||
| if err := sessions.Remove(ctx, s.Name); err != nil { | ||
| log.Debugf(ctx, "Failed to remove session %s from state: %v", s.Name, err) | ||
| } | ||
|
|
||
| log.Infof(ctx, "Cleaned up stale session: %s", s.Name) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,28 @@ | ||
| package sessions | ||
|
|
||
| import ( | ||
| "crypto/rand" | ||
| "encoding/hex" | ||
| "strings" | ||
| "time" | ||
| ) | ||
|
|
||
| // acceleratorPrefixes maps known accelerator types to short human-readable prefixes. | ||
| var acceleratorPrefixes = map[string]string{ | ||
| "GPU_1xA10": "gpu-a10", | ||
| "GPU_8xH100": "gpu-h100", | ||
| } | ||
|
|
||
| // GenerateSessionName creates a human-readable session name from the accelerator type. | ||
| // Format: <prefix>-<random_hex>, e.g. "gpu-a10-f3a2b1c0". | ||
| func GenerateSessionName(accelerator string) string { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As mentioned above, will it help with known_hosts conflicts if we add a workspace id/host here? |
||
| prefix, ok := acceleratorPrefixes[accelerator] | ||
| if !ok { | ||
| prefix = strings.ToLower(strings.ReplaceAll(accelerator, "_", "-")) | ||
| } | ||
|
|
||
| date := time.Now().Format("20060102") | ||
| b := make([]byte, 3) | ||
| _, _ = rand.Read(b) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Agent Swarm Review] [Nit]
|
||
| return "databricks-" + prefix + "-" + date + "-" + hex.EncodeToString(b) | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We've had such options before, but the security didn't like it.
With auto host name generation we should not have that many host conflicts, right?
Before you would get them if you re-used the same name to connect to a different workspace. Re-using the same name for the same workspace is fine, as our server will get the server ssh key from the secrets scope that's tied to the name (and with the same name the scope will be the same). But across different workspaces we will get a problem, since server keys will be different.
Can we also add workspace id (real one, or based on the host url) to the generated session name?