Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion cmd/apps/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -717,9 +717,29 @@ func runCreate(ctx context.Context, opts createOptions) error {
// Always include mandatory plugins regardless of user selection or flags.
selectedPlugins = appendUnique(selectedPlugins, m.GetMandatoryPluginNames()...)

// In flags/non-interactive mode, validate that all required resources are provided.
// In flags/non-interactive mode, resolve derived postgres values and validate resources.
if flagsMode || !isInteractive {
resources := m.CollectResources(selectedPlugins)

// Resolve derived values for resources that support it.
for _, r := range resources {
resolveFn, ok := prompt.GetResolveFunc(r.Type)
if !ok {
continue
}
resolved, err := resolveFn(ctx, r, resourceValues)
if err != nil {
log.Warnf(ctx, "Could not resolve derived values for %s: %v", r.Alias, err)
continue
}
for k, v := range resolved {
if resourceValues[k] == "" {
resourceValues[k] = v
}
}
}

// Validate that all required resources are provided.
for _, r := range resources {
found := false
for k := range resourceValues {
Expand Down
92 changes: 60 additions & 32 deletions libs/apps/prompt/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"maps"
"os"
"path/filepath"
"regexp"
Expand Down Expand Up @@ -564,29 +565,6 @@ func PromptForPostgres(ctx context.Context, r manifest.Resource, required bool)
return nil, nil
}

// Step 2.5: resolve endpoint details from the branch (non-fatal).
var host, endpointPath string
endpointErr := RunWithSpinnerCtx(ctx, "Resolving connection details...", func() error {
endpoints, fetchErr := ListPostgresEndpoints(ctx, branchName)
if fetchErr != nil {
return fetchErr
}
for _, ep := range endpoints {
if ep.Status != nil && ep.Status.EndpointType == postgres.EndpointTypeEndpointTypeReadWrite {
endpointPath = ep.Name
if ep.Status.Hosts != nil && ep.Status.Hosts.Host != "" {
host = ep.Status.Hosts.Host
}
break
}
}
return nil
})
if endpointErr != nil {
log.Warnf(ctx, "Could not resolve endpoint details: %v", endpointErr)
// non-fatal: user can fill values manually
}

// Step 3: pick a database within the branch
var databases []ListItem
err = RunWithSpinnerCtx(ctx, "Fetching databases...", func() error {
Expand All @@ -605,22 +583,72 @@ func PromptForPostgres(ctx context.Context, r manifest.Resource, required bool)
return nil, nil
}

// Build resolver results map keyed by resolver name.
resolvedValues := map[string]string{
"postgres:host": host,
"postgres:databaseName": pgDatabaseName,
"postgres:endpointPath": endpointPath,
}

// Start with prompted values (fields without resolve).
result := map[string]string{
r.Key() + ".branch": branchName,
r.Key() + ".database": dbName,
}

// Map resolved values to fields using the manifest's resolve property.
applyResolvedValues(r, resolvedValues, result)
// Resolve derived values (host, databaseName, endpointPath) — non-fatal.
resolved, resolveErr := ResolvePostgresValues(ctx, r, branchName, dbName, pgDatabaseName)
if resolveErr != nil {
log.Warnf(ctx, "Could not resolve connection details: %v", resolveErr)
}
maps.Copy(result, resolved)

return result, nil
}

// resolvePostgresResource adapts ResolvePostgresValues for the generic ResolveResourceFunc signature.
func resolvePostgresResource(ctx context.Context, r manifest.Resource, provided map[string]string) (map[string]string, error) {
branchName := provided[r.Key()+".branch"]
dbName := provided[r.Key()+".database"]
if branchName == "" || dbName == "" {
return nil, nil
}
return ResolvePostgresValues(ctx, r, branchName, dbName, "")
}

// ResolvePostgresValues resolves derived field values (host, databaseName, endpointPath)
// from a branch and database resource name. If pgDatabaseName is already known
// (e.g. from a prior prompt), pass it to skip the ListDatabases API call.
func ResolvePostgresValues(ctx context.Context, r manifest.Resource, branchName, dbName, pgDatabaseName string) (map[string]string, error) {
var host, endpointPath string
endpoints, err := ListPostgresEndpoints(ctx, branchName)
if err != nil {
return nil, fmt.Errorf("resolving endpoint details: %w", err)
}
for _, ep := range endpoints {
if ep.Status != nil && ep.Status.EndpointType == postgres.EndpointTypeEndpointTypeReadWrite {
endpointPath = ep.Name
if ep.Status.Hosts != nil && ep.Status.Hosts.Host != "" {
host = ep.Status.Hosts.Host
}
break
}
}

if pgDatabaseName == "" {
databases, err := ListPostgresDatabases(ctx, branchName)
if err != nil {
return nil, fmt.Errorf("resolving database name: %w", err)
}
for _, db := range databases {
if db.ID == dbName {
pgDatabaseName = db.Label
break
}
}
}

resolvedValues := map[string]string{
"postgres:host": host,
"postgres:databaseName": pgDatabaseName,
"postgres:endpointPath": endpointPath,
}

result := make(map[string]string)
applyResolvedValues(r, resolvedValues, result)
return result, nil
}

Expand Down
222 changes: 222 additions & 0 deletions libs/apps/prompt/resolve_postgres_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
package prompt

import (
"context"
"errors"
"testing"

"github.com/databricks/cli/libs/apps/manifest"
"github.com/databricks/cli/libs/cmdctx"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/databricks-sdk-go/experimental/mocks"
"github.com/databricks/databricks-sdk-go/listing"
"github.com/databricks/databricks-sdk-go/service/postgres"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

func newPostgresResource() manifest.Resource {
return manifest.Resource{
ResourceKey: "postgres",
Fields: map[string]manifest.ResourceField{
"branch": {Description: "branch path"},
"database": {Description: "database name"},
"host": {Resolve: "postgres:host"},
"databaseName": {Resolve: "postgres:databaseName"},
"endpointPath": {Resolve: "postgres:endpointPath"},
},
}
}

func TestResolvePostgresValuesHappyPath(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient)

branchName := "projects/p1/branches/main"
dbName := "projects/p1/branches/main/databases/mydb"

// Mock ListEndpoints
endpoints := listing.SliceIterator[postgres.Endpoint]{
{
Name: "projects/p1/branches/main/endpoints/ep1",
Status: &postgres.EndpointStatus{
EndpointType: postgres.EndpointTypeEndpointTypeReadWrite,
Hosts: &postgres.EndpointHosts{Host: "my-host.example.com"},
},
},
}
m.GetMockPostgresAPI().EXPECT().
ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}).
Return(&endpoints).Once()

// Mock ListDatabases
databases := listing.SliceIterator[postgres.Database]{
{
Name: dbName,
Status: &postgres.DatabaseDatabaseStatus{PostgresDatabase: "my_pg_db"},
},
}
m.GetMockPostgresAPI().EXPECT().
ListDatabases(mock.Anything, postgres.ListDatabasesRequest{Parent: branchName}).
Return(&databases).Once()

r := newPostgresResource()
result, err := ResolvePostgresValues(ctx, r, branchName, dbName, "")
require.NoError(t, err)

assert.Equal(t, map[string]string{
"postgres.host": "my-host.example.com",
"postgres.databaseName": "my_pg_db",
"postgres.endpointPath": "projects/p1/branches/main/endpoints/ep1",
}, result)
}

func TestResolvePostgresValuesNoReadWriteEndpoint(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient)

branchName := "projects/p1/branches/main"
dbName := "projects/p1/branches/main/databases/mydb"

// Return only a read-only endpoint.
endpoints := listing.SliceIterator[postgres.Endpoint]{
{
Name: "projects/p1/branches/main/endpoints/ep1",
Status: &postgres.EndpointStatus{
EndpointType: postgres.EndpointTypeEndpointTypeReadOnly,
},
},
}
m.GetMockPostgresAPI().EXPECT().
ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}).
Return(&endpoints).Once()

databases := listing.SliceIterator[postgres.Database]{
{
Name: dbName,
Status: &postgres.DatabaseDatabaseStatus{PostgresDatabase: "my_pg_db"},
},
}
m.GetMockPostgresAPI().EXPECT().
ListDatabases(mock.Anything, postgres.ListDatabasesRequest{Parent: branchName}).
Return(&databases).Once()

r := newPostgresResource()
result, err := ResolvePostgresValues(ctx, r, branchName, dbName, "")
require.NoError(t, err)

// host and endpointPath should be empty since no ReadWrite endpoint found.
assert.Equal(t, map[string]string{
"postgres.host": "",
"postgres.databaseName": "my_pg_db",
"postgres.endpointPath": "",
}, result)
}

func TestResolvePostgresValuesDatabaseNotFound(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient)

branchName := "projects/p1/branches/main"
dbName := "projects/p1/branches/main/databases/nonexistent"

endpoints := listing.SliceIterator[postgres.Endpoint]{
{
Name: "projects/p1/branches/main/endpoints/ep1",
Status: &postgres.EndpointStatus{
EndpointType: postgres.EndpointTypeEndpointTypeReadWrite,
Hosts: &postgres.EndpointHosts{Host: "my-host.example.com"},
},
},
}
m.GetMockPostgresAPI().EXPECT().
ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}).
Return(&endpoints).Once()

// Return databases that don't match dbName.
databases := listing.SliceIterator[postgres.Database]{
{
Name: "projects/p1/branches/main/databases/other",
Status: &postgres.DatabaseDatabaseStatus{PostgresDatabase: "other_db"},
},
}
m.GetMockPostgresAPI().EXPECT().
ListDatabases(mock.Anything, postgres.ListDatabasesRequest{Parent: branchName}).
Return(&databases).Once()

r := newPostgresResource()
result, err := ResolvePostgresValues(ctx, r, branchName, dbName, "")
require.NoError(t, err)

// databaseName should be empty since no match.
assert.Equal(t, map[string]string{
"postgres.host": "my-host.example.com",
"postgres.databaseName": "",
"postgres.endpointPath": "projects/p1/branches/main/endpoints/ep1",
}, result)
}

func TestResolvePostgresValuesSkipsDatabaseListWhenNameProvided(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient)

branchName := "projects/p1/branches/main"
dbName := "projects/p1/branches/main/databases/mydb"

endpoints := listing.SliceIterator[postgres.Endpoint]{
{
Name: "projects/p1/branches/main/endpoints/ep1",
Status: &postgres.EndpointStatus{
EndpointType: postgres.EndpointTypeEndpointTypeReadWrite,
Hosts: &postgres.EndpointHosts{Host: "my-host.example.com"},
},
},
}
m.GetMockPostgresAPI().EXPECT().
ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}).
Return(&endpoints).Once()

// ListDatabases should NOT be called since pgDatabaseName is pre-provided.

r := newPostgresResource()
result, err := ResolvePostgresValues(ctx, r, branchName, dbName, "my_pg_db")
require.NoError(t, err)

assert.Equal(t, map[string]string{
"postgres.host": "my-host.example.com",
"postgres.databaseName": "my_pg_db",
"postgres.endpointPath": "projects/p1/branches/main/endpoints/ep1",
}, result)
}

func TestResolvePostgresValuesEndpointAPIError(t *testing.T) {
m := mocks.NewMockWorkspaceClient(t)
ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient)

branchName := "projects/p1/branches/main"

// Return an iterator that yields an error.
m.GetMockPostgresAPI().EXPECT().
ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}).
RunAndReturn(func(_ context.Context, _ postgres.ListEndpointsRequest) listing.Iterator[postgres.Endpoint] {
return &errorIterator[postgres.Endpoint]{err: errors.New("API unavailable")}
}).Once()

r := newPostgresResource()
_, err := ResolvePostgresValues(ctx, r, branchName, "some-db", "")
require.Error(t, err)
assert.Contains(t, err.Error(), "resolving endpoint details")
}

// errorIterator is a test helper that always returns an error.
type errorIterator[T any] struct {
err error
}

func (e *errorIterator[T]) HasNext(_ context.Context) bool { return true }

func (e *errorIterator[T]) Next(_ context.Context) (T, error) {
var zero T
return zero, e.err
}
15 changes: 15 additions & 0 deletions libs/apps/prompt/resource_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,21 @@ const (
// keys use the format "resource_key.field_name" (e.g., {"database.instance_name": "x", "database.database_name": "y"}).
type PromptResourceFunc func(ctx context.Context, r manifest.Resource, required bool) (map[string]string, error)

// ResolveResourceFunc resolves derived field values for a resource.
// It receives the resource and already-known values (from prompts or flags),
// returning additional derived values to merge in.
type ResolveResourceFunc func(ctx context.Context, r manifest.Resource, provided map[string]string) (map[string]string, error)

// GetResolveFunc returns the resolve function for the given resource type, or (nil, false) if not needed.
func GetResolveFunc(resourceType string) (ResolveResourceFunc, bool) {
switch resourceType {
case ResourceTypePostgres:
return resolvePostgresResource, true
default:
return nil, false
}
}

// GetPromptFunc returns the prompt function for the given resource type, or (nil, false) if not supported.
func GetPromptFunc(resourceType string) (PromptResourceFunc, bool) {
switch resourceType {
Expand Down
Loading