diff --git a/cmd/apps/init.go b/cmd/apps/init.go index a94e2c9fb2..adbc2866ba 100644 --- a/cmd/apps/init.go +++ b/cmd/apps/init.go @@ -717,9 +717,29 @@ func runCreate(ctx context.Context, opts createOptions) error { // Always include mandatory plugins regardless of user selection or flags. selectedPlugins = appendUnique(selectedPlugins, m.GetMandatoryPluginNames()...) - // In flags/non-interactive mode, validate that all required resources are provided. + // In flags/non-interactive mode, resolve derived postgres values and validate resources. if flagsMode || !isInteractive { resources := m.CollectResources(selectedPlugins) + + // Resolve derived values for resources that support it. + for _, r := range resources { + resolveFn, ok := prompt.GetResolveFunc(r.Type) + if !ok { + continue + } + resolved, err := resolveFn(ctx, r, resourceValues) + if err != nil { + log.Warnf(ctx, "Could not resolve derived values for %s: %v", r.Alias, err) + continue + } + for k, v := range resolved { + if resourceValues[k] == "" { + resourceValues[k] = v + } + } + } + + // Validate that all required resources are provided. for _, r := range resources { found := false for k := range resourceValues { diff --git a/libs/apps/prompt/prompt.go b/libs/apps/prompt/prompt.go index d24d3a8715..fe661b8b68 100644 --- a/libs/apps/prompt/prompt.go +++ b/libs/apps/prompt/prompt.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "maps" "os" "path/filepath" "regexp" @@ -564,29 +565,6 @@ func PromptForPostgres(ctx context.Context, r manifest.Resource, required bool) return nil, nil } - // Step 2.5: resolve endpoint details from the branch (non-fatal). - var host, endpointPath string - endpointErr := RunWithSpinnerCtx(ctx, "Resolving connection details...", func() error { - endpoints, fetchErr := ListPostgresEndpoints(ctx, branchName) - if fetchErr != nil { - return fetchErr - } - for _, ep := range endpoints { - if ep.Status != nil && ep.Status.EndpointType == postgres.EndpointTypeEndpointTypeReadWrite { - endpointPath = ep.Name - if ep.Status.Hosts != nil && ep.Status.Hosts.Host != "" { - host = ep.Status.Hosts.Host - } - break - } - } - return nil - }) - if endpointErr != nil { - log.Warnf(ctx, "Could not resolve endpoint details: %v", endpointErr) - // non-fatal: user can fill values manually - } - // Step 3: pick a database within the branch var databases []ListItem err = RunWithSpinnerCtx(ctx, "Fetching databases...", func() error { @@ -605,22 +583,72 @@ func PromptForPostgres(ctx context.Context, r manifest.Resource, required bool) return nil, nil } - // Build resolver results map keyed by resolver name. - resolvedValues := map[string]string{ - "postgres:host": host, - "postgres:databaseName": pgDatabaseName, - "postgres:endpointPath": endpointPath, - } - // Start with prompted values (fields without resolve). result := map[string]string{ r.Key() + ".branch": branchName, r.Key() + ".database": dbName, } - // Map resolved values to fields using the manifest's resolve property. - applyResolvedValues(r, resolvedValues, result) + // Resolve derived values (host, databaseName, endpointPath) — non-fatal. + resolved, resolveErr := ResolvePostgresValues(ctx, r, branchName, dbName, pgDatabaseName) + if resolveErr != nil { + log.Warnf(ctx, "Could not resolve connection details: %v", resolveErr) + } + maps.Copy(result, resolved) + + return result, nil +} + +// resolvePostgresResource adapts ResolvePostgresValues for the generic ResolveResourceFunc signature. +func resolvePostgresResource(ctx context.Context, r manifest.Resource, provided map[string]string) (map[string]string, error) { + branchName := provided[r.Key()+".branch"] + dbName := provided[r.Key()+".database"] + if branchName == "" || dbName == "" { + return nil, nil + } + return ResolvePostgresValues(ctx, r, branchName, dbName, "") +} + +// ResolvePostgresValues resolves derived field values (host, databaseName, endpointPath) +// from a branch and database resource name. If pgDatabaseName is already known +// (e.g. from a prior prompt), pass it to skip the ListDatabases API call. +func ResolvePostgresValues(ctx context.Context, r manifest.Resource, branchName, dbName, pgDatabaseName string) (map[string]string, error) { + var host, endpointPath string + endpoints, err := ListPostgresEndpoints(ctx, branchName) + if err != nil { + return nil, fmt.Errorf("resolving endpoint details: %w", err) + } + for _, ep := range endpoints { + if ep.Status != nil && ep.Status.EndpointType == postgres.EndpointTypeEndpointTypeReadWrite { + endpointPath = ep.Name + if ep.Status.Hosts != nil && ep.Status.Hosts.Host != "" { + host = ep.Status.Hosts.Host + } + break + } + } + + if pgDatabaseName == "" { + databases, err := ListPostgresDatabases(ctx, branchName) + if err != nil { + return nil, fmt.Errorf("resolving database name: %w", err) + } + for _, db := range databases { + if db.ID == dbName { + pgDatabaseName = db.Label + break + } + } + } + resolvedValues := map[string]string{ + "postgres:host": host, + "postgres:databaseName": pgDatabaseName, + "postgres:endpointPath": endpointPath, + } + + result := make(map[string]string) + applyResolvedValues(r, resolvedValues, result) return result, nil } diff --git a/libs/apps/prompt/resolve_postgres_test.go b/libs/apps/prompt/resolve_postgres_test.go new file mode 100644 index 0000000000..49e585049c --- /dev/null +++ b/libs/apps/prompt/resolve_postgres_test.go @@ -0,0 +1,222 @@ +package prompt + +import ( + "context" + "errors" + "testing" + + "github.com/databricks/cli/libs/apps/manifest" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/databricks-sdk-go/experimental/mocks" + "github.com/databricks/databricks-sdk-go/listing" + "github.com/databricks/databricks-sdk-go/service/postgres" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func newPostgresResource() manifest.Resource { + return manifest.Resource{ + ResourceKey: "postgres", + Fields: map[string]manifest.ResourceField{ + "branch": {Description: "branch path"}, + "database": {Description: "database name"}, + "host": {Resolve: "postgres:host"}, + "databaseName": {Resolve: "postgres:databaseName"}, + "endpointPath": {Resolve: "postgres:endpointPath"}, + }, + } +} + +func TestResolvePostgresValuesHappyPath(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + + branchName := "projects/p1/branches/main" + dbName := "projects/p1/branches/main/databases/mydb" + + // Mock ListEndpoints + endpoints := listing.SliceIterator[postgres.Endpoint]{ + { + Name: "projects/p1/branches/main/endpoints/ep1", + Status: &postgres.EndpointStatus{ + EndpointType: postgres.EndpointTypeEndpointTypeReadWrite, + Hosts: &postgres.EndpointHosts{Host: "my-host.example.com"}, + }, + }, + } + m.GetMockPostgresAPI().EXPECT(). + ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}). + Return(&endpoints).Once() + + // Mock ListDatabases + databases := listing.SliceIterator[postgres.Database]{ + { + Name: dbName, + Status: &postgres.DatabaseDatabaseStatus{PostgresDatabase: "my_pg_db"}, + }, + } + m.GetMockPostgresAPI().EXPECT(). + ListDatabases(mock.Anything, postgres.ListDatabasesRequest{Parent: branchName}). + Return(&databases).Once() + + r := newPostgresResource() + result, err := ResolvePostgresValues(ctx, r, branchName, dbName, "") + require.NoError(t, err) + + assert.Equal(t, map[string]string{ + "postgres.host": "my-host.example.com", + "postgres.databaseName": "my_pg_db", + "postgres.endpointPath": "projects/p1/branches/main/endpoints/ep1", + }, result) +} + +func TestResolvePostgresValuesNoReadWriteEndpoint(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + + branchName := "projects/p1/branches/main" + dbName := "projects/p1/branches/main/databases/mydb" + + // Return only a read-only endpoint. + endpoints := listing.SliceIterator[postgres.Endpoint]{ + { + Name: "projects/p1/branches/main/endpoints/ep1", + Status: &postgres.EndpointStatus{ + EndpointType: postgres.EndpointTypeEndpointTypeReadOnly, + }, + }, + } + m.GetMockPostgresAPI().EXPECT(). + ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}). + Return(&endpoints).Once() + + databases := listing.SliceIterator[postgres.Database]{ + { + Name: dbName, + Status: &postgres.DatabaseDatabaseStatus{PostgresDatabase: "my_pg_db"}, + }, + } + m.GetMockPostgresAPI().EXPECT(). + ListDatabases(mock.Anything, postgres.ListDatabasesRequest{Parent: branchName}). + Return(&databases).Once() + + r := newPostgresResource() + result, err := ResolvePostgresValues(ctx, r, branchName, dbName, "") + require.NoError(t, err) + + // host and endpointPath should be empty since no ReadWrite endpoint found. + assert.Equal(t, map[string]string{ + "postgres.host": "", + "postgres.databaseName": "my_pg_db", + "postgres.endpointPath": "", + }, result) +} + +func TestResolvePostgresValuesDatabaseNotFound(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + + branchName := "projects/p1/branches/main" + dbName := "projects/p1/branches/main/databases/nonexistent" + + endpoints := listing.SliceIterator[postgres.Endpoint]{ + { + Name: "projects/p1/branches/main/endpoints/ep1", + Status: &postgres.EndpointStatus{ + EndpointType: postgres.EndpointTypeEndpointTypeReadWrite, + Hosts: &postgres.EndpointHosts{Host: "my-host.example.com"}, + }, + }, + } + m.GetMockPostgresAPI().EXPECT(). + ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}). + Return(&endpoints).Once() + + // Return databases that don't match dbName. + databases := listing.SliceIterator[postgres.Database]{ + { + Name: "projects/p1/branches/main/databases/other", + Status: &postgres.DatabaseDatabaseStatus{PostgresDatabase: "other_db"}, + }, + } + m.GetMockPostgresAPI().EXPECT(). + ListDatabases(mock.Anything, postgres.ListDatabasesRequest{Parent: branchName}). + Return(&databases).Once() + + r := newPostgresResource() + result, err := ResolvePostgresValues(ctx, r, branchName, dbName, "") + require.NoError(t, err) + + // databaseName should be empty since no match. + assert.Equal(t, map[string]string{ + "postgres.host": "my-host.example.com", + "postgres.databaseName": "", + "postgres.endpointPath": "projects/p1/branches/main/endpoints/ep1", + }, result) +} + +func TestResolvePostgresValuesSkipsDatabaseListWhenNameProvided(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + + branchName := "projects/p1/branches/main" + dbName := "projects/p1/branches/main/databases/mydb" + + endpoints := listing.SliceIterator[postgres.Endpoint]{ + { + Name: "projects/p1/branches/main/endpoints/ep1", + Status: &postgres.EndpointStatus{ + EndpointType: postgres.EndpointTypeEndpointTypeReadWrite, + Hosts: &postgres.EndpointHosts{Host: "my-host.example.com"}, + }, + }, + } + m.GetMockPostgresAPI().EXPECT(). + ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}). + Return(&endpoints).Once() + + // ListDatabases should NOT be called since pgDatabaseName is pre-provided. + + r := newPostgresResource() + result, err := ResolvePostgresValues(ctx, r, branchName, dbName, "my_pg_db") + require.NoError(t, err) + + assert.Equal(t, map[string]string{ + "postgres.host": "my-host.example.com", + "postgres.databaseName": "my_pg_db", + "postgres.endpointPath": "projects/p1/branches/main/endpoints/ep1", + }, result) +} + +func TestResolvePostgresValuesEndpointAPIError(t *testing.T) { + m := mocks.NewMockWorkspaceClient(t) + ctx := cmdctx.SetWorkspaceClient(cmdio.MockDiscard(t.Context()), m.WorkspaceClient) + + branchName := "projects/p1/branches/main" + + // Return an iterator that yields an error. + m.GetMockPostgresAPI().EXPECT(). + ListEndpoints(mock.Anything, postgres.ListEndpointsRequest{Parent: branchName}). + RunAndReturn(func(_ context.Context, _ postgres.ListEndpointsRequest) listing.Iterator[postgres.Endpoint] { + return &errorIterator[postgres.Endpoint]{err: errors.New("API unavailable")} + }).Once() + + r := newPostgresResource() + _, err := ResolvePostgresValues(ctx, r, branchName, "some-db", "") + require.Error(t, err) + assert.Contains(t, err.Error(), "resolving endpoint details") +} + +// errorIterator is a test helper that always returns an error. +type errorIterator[T any] struct { + err error +} + +func (e *errorIterator[T]) HasNext(_ context.Context) bool { return true } + +func (e *errorIterator[T]) Next(_ context.Context) (T, error) { + var zero T + return zero, e.err +} diff --git a/libs/apps/prompt/resource_registry.go b/libs/apps/prompt/resource_registry.go index 92b6331c62..e6846ad2f2 100644 --- a/libs/apps/prompt/resource_registry.go +++ b/libs/apps/prompt/resource_registry.go @@ -30,6 +30,21 @@ const ( // keys use the format "resource_key.field_name" (e.g., {"database.instance_name": "x", "database.database_name": "y"}). type PromptResourceFunc func(ctx context.Context, r manifest.Resource, required bool) (map[string]string, error) +// ResolveResourceFunc resolves derived field values for a resource. +// It receives the resource and already-known values (from prompts or flags), +// returning additional derived values to merge in. +type ResolveResourceFunc func(ctx context.Context, r manifest.Resource, provided map[string]string) (map[string]string, error) + +// GetResolveFunc returns the resolve function for the given resource type, or (nil, false) if not needed. +func GetResolveFunc(resourceType string) (ResolveResourceFunc, bool) { + switch resourceType { + case ResourceTypePostgres: + return resolvePostgresResource, true + default: + return nil, false + } +} + // GetPromptFunc returns the prompt function for the given resource type, or (nil, false) if not supported. func GetPromptFunc(resourceType string) (PromptResourceFunc, bool) { switch resourceType {