From f15a8ee22f21635b5291ed1c07d17bce274218c1 Mon Sep 17 00:00:00 2001 From: Guillaume Date: Sat, 23 May 2026 22:56:46 +0200 Subject: [PATCH 1/2] Add support for extensions in the IR (#436) Extensions installed in the database are now read into the IR via a new `Extension` entity and emitted by the diff/dump path so dumps remain replayable on a fresh cluster. Without this, schemas that rely on extension-provided types or operator classes (e.g., a GIST index on UUID via `btree_gist`) couldn't be reproduced from `pgschema dump` output. Implementation: - `ir.Extension` lives at the IR root (cluster-level, not per-schema). - `Inspector.buildExtensions` queries `pg_extension` via a new sqlc query, excluding the always-present `plpgsql`. - `internal/diff/extension.go` emits `CREATE EXTENSION IF NOT EXISTS` at the head of the create-phase output and `DROP EXTENSION IF EXISTS` at the tail of the drop-phase output. - `WITH SCHEMA` is intentionally omitted from emission. Preserving the user-declared install schema requires more care than this PR takes on (`pg_extension` shows the *actual* install schema, which becomes pgschema's temporary schema during plan generation). Tracked as a follow-up. Note on tests: extensions are cluster-level state and persist across test cases on the shared embedded postgres instance. Without cleanup between cases, any setup.sql that installs an extension would leak into later tests' inspections. `TestPlanAndApply` now registers a `t.Cleanup`, for each case that runs a setup.sql, that drops non-default extensions when that case exits.
Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/migrate_integration_test.go | 45 ++++++++++++++ internal/diff/diff.go | 37 ++++++++++++ internal/diff/extension.go | 60 +++++++++++++++++++ ir/inspector.go | 25 ++++++++ ir/ir.go | 39 ++++++++++-- ir/queries/queries.sql | 17 +++++- ir/queries/queries.sql.go | 51 ++++++++++++++++ .../create_extension/add_extension/diff.sql | 1 + .../create_extension/add_extension/new.sql | 1 + .../create_extension/add_extension/old.sql | 1 + .../create_extension/add_extension/plan.json | 20 +++++++ .../create_extension/add_extension/plan.sql | 1 + .../create_extension/add_extension/plan.txt | 8 +++ 13 files changed, 301 insertions(+), 5 deletions(-) create mode 100644 internal/diff/extension.go create mode 100644 testdata/diff/create_extension/add_extension/diff.sql create mode 100644 testdata/diff/create_extension/add_extension/new.sql create mode 100644 testdata/diff/create_extension/add_extension/old.sql create mode 100644 testdata/diff/create_extension/add_extension/plan.json create mode 100644 testdata/diff/create_extension/add_extension/plan.sql create mode 100644 testdata/diff/create_extension/add_extension/plan.txt diff --git a/cmd/migrate_integration_test.go b/cmd/migrate_integration_test.go index e0dcfd10..869ee7ae 100644 --- a/cmd/migrate_integration_test.go +++ b/cmd/migrate_integration_test.go @@ -247,6 +247,14 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *struct { if _, err := embeddedConn.ExecContext(ctx, string(setupContent)); err != nil { t.Fatalf("Failed to execute setup.sql to embedded postgres: %v", err) } + + // Extensions are cluster-level and persist across tests on the shared + // embedded postgres instance. Reset any non-default extensions when + // this test case finishes so the next case (or the next top-level test) + // inherits a clean cluster. 
+ t.Cleanup(func() { + cleanupSharedClusterExtensions(t) + }) } } @@ -576,3 +584,40 @@ func matchesFilter(relPath, filter string) bool { // Fallback: check if filter is a substring of the path return strings.Contains(relPath, filter) } + +// cleanupSharedClusterExtensions drops any extensions on the shared embedded +// postgres instance other than the built-in `plpgsql`. Extensions are +// cluster-level state and survive per-database resets, so without this teardown +// a setup.sql that installs (say) btree_gist or hstore would leak into every +// subsequent test that inspects the cluster. +func cleanupSharedClusterExtensions(t *testing.T) { + t.Helper() + if sharedEmbeddedPG == nil { + return + } + conn, _, _, _, _, _ := testutil.ConnectToPostgres(t, sharedEmbeddedPG) + defer conn.Close() + + ctx := context.Background() + rows, err := conn.QueryContext(ctx, "SELECT extname FROM pg_extension WHERE extname <> 'plpgsql'") + if err != nil { + t.Logf("extension cleanup: query failed (continuing): %v", err) + return + } + var names []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + t.Logf("extension cleanup: scan failed (continuing): %v", err) + continue + } + names = append(names, name) + } + rows.Close() + + for _, name := range names { + if _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP EXTENSION IF EXISTS %q CASCADE", name)); err != nil { + t.Logf("extension cleanup: failed to drop %s (continuing): %v", name, err) + } + } +} diff --git a/internal/diff/diff.go b/internal/diff/diff.go index 7870cba6..df2572dc 100644 --- a/internal/diff/diff.go +++ b/internal/diff/diff.go @@ -42,6 +42,7 @@ const ( DiffTypePrivilege DiffTypeRevokedDefaultPrivilege DiffTypeColumnPrivilege + DiffTypeExtension ) // String returns the string representation of DiffType @@ -103,6 +104,8 @@ func (d DiffType) String() string { return "revoked_default_privilege" case DiffTypeColumnPrivilege: return "column_privilege" + case DiffTypeExtension: + return 
"extension" default: return "unknown" } @@ -177,6 +180,8 @@ func (d *DiffType) UnmarshalJSON(data []byte) error { *d = DiffTypeRevokedDefaultPrivilege case "column_privilege": *d = DiffTypeColumnPrivilege + case "extension": + *d = DiffTypeExtension default: return fmt.Errorf("unknown diff type: %s", s) } @@ -296,6 +301,9 @@ type ddlDiff struct { addedColumnPrivileges []*ir.ColumnPrivilege droppedColumnPrivileges []*ir.ColumnPrivilege modifiedColumnPrivileges []*columnPrivilegeDiff + // Cluster-level extensions + addedExtensions []*ir.Extension + droppedExtensions []*ir.Extension } // schemaDiff represents changes to a schema @@ -460,6 +468,27 @@ func GenerateMigration(oldIR, newIR *ir.IR, targetSchema string) []Diff { addedColumnPrivileges: []*ir.ColumnPrivilege{}, droppedColumnPrivileges: []*ir.ColumnPrivilege{}, modifiedColumnPrivileges: []*columnPrivilegeDiff{}, + addedExtensions: []*ir.Extension{}, + droppedExtensions: []*ir.Extension{}, + } + + // Compute extension diffs (cluster-level, so no schema filtering). + // Modifications (version bumps) are out of scope for this initial PR; only + // added/dropped are tracked. See #436 for the broader extension story. 
+ { + extNames := sortedKeys(newIR.Extensions) + for _, name := range extNames { + newExt := newIR.Extensions[name] + if _, exists := oldIR.Extensions[name]; !exists { + diff.addedExtensions = append(diff.addedExtensions, newExt) + } + } + oldExtNames := sortedKeys(oldIR.Extensions) + for _, name := range oldExtNames { + if _, exists := newIR.Extensions[name]; !exists { + diff.droppedExtensions = append(diff.droppedExtensions, oldIR.Extensions[name]) + } + } } // Compare schemas first in deterministic order @@ -1499,6 +1528,10 @@ func (d *ddlDiff) generatePreDropMaterializedViewsSQL(targetSchema string, colle func (d *ddlDiff) generateCreateSQL(targetSchema string, collector *diffCollector) { // Note: Schema creation is out of scope for schema-level comparisons + // Extensions first: they provide operator classes, types, and functions that + // downstream schema objects (e.g., a GIST index on UUID via btree_gist) depend on. + generateCreateExtensionsSQL(d.addedExtensions, collector) + // Build function lookup early - needed for both domain and table dependency checks newFunctionLookup := buildFunctionLookup(d.addedFunctions) @@ -1721,6 +1754,10 @@ func (d *ddlDiff) generateDropSQL(targetSchema string, collector *diffCollector, // Drop types generateDropTypesSQL(d.droppedTypes, targetSchema, collector) + // Drop extensions last: any schema object that depended on the extension + // must already be gone before we try to drop the extension itself. + generateDropExtensionsSQL(d.droppedExtensions, collector) + // Drop schemas // Note: Schema deletion is out of scope for schema-level comparisons } diff --git a/internal/diff/extension.go b/internal/diff/extension.go new file mode 100644 index 00000000..b8e223df --- /dev/null +++ b/internal/diff/extension.go @@ -0,0 +1,60 @@ +package diff + +import ( + "fmt" + + "github.com/pgplex/pgschema/ir" +) + +// generateCreateExtensionsSQL generates `CREATE EXTENSION IF NOT EXISTS` statements +// for newly added extensions. 
Emitted before any schema-level objects because +// extensions can provide operator classes, types, and functions that those +// objects depend on (e.g., a GIST index using btree_gist's UUID operator class). +func generateCreateExtensionsSQL(extensions []*ir.Extension, collector *diffCollector) { + for _, ext := range extensions { + sql := generateExtensionSQL(ext) + context := &diffContext{ + Type: DiffTypeExtension, + Operation: DiffOperationCreate, + Path: extensionPath(ext), + Source: ext, + CanRunInTransaction: true, + } + collector.collect(context, sql) + } +} + +// generateDropExtensionsSQL generates `DROP EXTENSION IF EXISTS` statements +// for extensions removed from the target. Emitted after all schema-level drops +// to avoid dependency conflicts. +func generateDropExtensionsSQL(extensions []*ir.Extension, collector *diffCollector) { + for _, ext := range extensions { + context := &diffContext{ + Type: DiffTypeExtension, + Operation: DiffOperationDrop, + Path: extensionPath(ext), + Source: ext, + CanRunInTransaction: true, + } + collector.collect(context, fmt.Sprintf("DROP EXTENSION IF EXISTS %s;", ext.Name)) + } +} + +// extensionPath returns the identifier used in the diff Path field. Extensions +// are cluster-level so no schema qualifier is included; doing so would leak +// the plan command's temporary schema into the recorded plan and break +// golden-output stability across runs. +func extensionPath(ext *ir.Extension) string { + return ext.Name +} + +// generateExtensionSQL renders a single CREATE EXTENSION statement. +// Extensions are cluster-level; the installed schema is intentionally not +// emitted here. Honoring it would require either pinning it to the user's +// declared value (which we cannot recover from pg_extension alone — the plan +// command's temporary schema becomes the install schema when no WITH SCHEMA +// is given) or filtering out transient schemas. 
Preserving the user-declared +// install schema is tracked as a follow-up to #436. +func generateExtensionSQL(ext *ir.Extension) string { + return fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", ext.Name) +} diff --git a/ir/inspector.go b/ir/inspector.go index 23a0a5c4..d5a47e18 100644 --- a/ir/inspector.go +++ b/ir/inspector.go @@ -57,6 +57,10 @@ func (i *Inspector) BuildIR(ctx context.Context, targetSchema string) (*IR, erro return nil, fmt.Errorf("failed to build metadata: %w", err) } + if err := i.buildExtensions(ctx, schema); err != nil { + return nil, fmt.Errorf("failed to build extensions: %w", err) + } + if err := i.validateSchemaExists(ctx, targetSchema); err != nil { return nil, err } @@ -207,6 +211,27 @@ func (i *Inspector) buildMetadata(ctx context.Context, schema *IR) error { return nil } +// buildExtensions records every installed extension (except plpgsql) on the IR. +// Extensions are cluster-level — they are not scoped by targetSchema. +func (i *Inspector) buildExtensions(ctx context.Context, schema *IR) error { + rows, err := i.queries.GetExtensions(ctx) + if err != nil { + return err + } + for _, row := range rows { + if !row.ExtensionName.Valid { + continue + } + schema.SetExtension(&Extension{ + Name: row.ExtensionName.String, + Version: row.ExtensionVersion, + Schema: row.ExtensionSchema.String, + Comment: row.ExtensionComment.String, + }) + } + return nil +} + func (i *Inspector) buildSchemas(ctx context.Context, schema *IR, targetSchema string) error { // Use the schema-specific query to prefilter at the database level schemaName, err := i.queries.GetSchema(ctx, sql.NullString{String: targetSchema, Valid: true}) diff --git a/ir/ir.go b/ir/ir.go index fb0ff718..16d4ee48 100644 --- a/ir/ir.go +++ b/ir/ir.go @@ -8,9 +8,20 @@ import ( // IR represents the complete database schema intermediate representation type IR struct { - Metadata Metadata `json:"metadata"` - Schemas map[string]*Schema `json:"schemas"` // schema_name -> Schema - mu 
sync.RWMutex // Protects concurrent access to Schemas + Metadata Metadata `json:"metadata"` + Extensions map[string]*Extension `json:"extensions,omitempty"` // extension_name -> Extension (cluster-level, not per-schema) + Schemas map[string]*Schema `json:"schemas"` // schema_name -> Schema + mu sync.RWMutex // Protects concurrent access to Schemas and Extensions +} + +// Extension represents a PostgreSQL extension installed in the database. +// Extensions are cluster-level (installed once per database), so they live at +// the IR root rather than under a Schema. +type Extension struct { + Name string `json:"name"` // e.g., "btree_gist" + Version string `json:"version,omitempty"` // e.g., "1.7" + Schema string `json:"schema,omitempty"` // Namespace where the extension's default objects are installed + Comment string `json:"comment,omitempty"` } // Metadata contains information about the schema dump @@ -542,10 +553,29 @@ func (cp *ColumnPrivilege) GetObjectName() string { // NewIR creates a new empty catalog IR func NewIR() *IR { return &IR{ - Schemas: make(map[string]*Schema), + Schemas: make(map[string]*Schema), + Extensions: make(map[string]*Extension), } } +// SetExtension records an extension on the IR with thread safety. +func (c *IR) SetExtension(ext *Extension) { + c.mu.Lock() + defer c.mu.Unlock() + if c.Extensions == nil { + c.Extensions = make(map[string]*Extension) + } + c.Extensions[ext.Name] = ext +} + +// GetExtension retrieves an extension by name with thread safety. +func (c *IR) GetExtension(name string) (*Extension, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + ext, ok := c.Extensions[name] + return ext, ok +} + // GetSchema retrieves a schema by name with thread safety. // Returns the schema and true if found, or nil and false if not found. 
func (c *IR) GetSchema(name string) (*Schema, bool) { @@ -709,4 +739,5 @@ func (p *Procedure) GetObjectName() string { return p.Name } func (v *View) GetObjectName() string { return v.Name } func (s *Sequence) GetObjectName() string { return s.Name } func (t *Type) GetObjectName() string { return t.Name } +func (e *Extension) GetObjectName() string { return e.Name } diff --git a/ir/queries/queries.sql b/ir/queries/queries.sql index fab04e86..862aad7c 100644 --- a/ir/queries/queries.sql +++ b/ir/queries/queries.sql @@ -1446,4 +1446,19 @@ JOIN pg_namespace referenced_ns ON referenced_proc.pronamespace = referenced_ns. WHERE d.classid = 'pg_proc'::regclass AND d.refclassid = 'pg_proc'::regclass AND d.deptype = 'n' - AND dependent_ns.nspname = $1; \ No newline at end of file + AND dependent_ns.nspname = $1; + +-- GetExtensions retrieves all installed extensions except the always-present +-- `plpgsql` built-in. Used to render `CREATE EXTENSION` statements in the dump +-- so dumps remain replayable on a fresh database. 
+-- name: GetExtensions :many +SELECT + e.extname::text AS extension_name, + e.extversion::text AS extension_version, + n.nspname::text AS extension_schema, + COALESCE(d.description, '') AS extension_comment +FROM pg_extension e +JOIN pg_namespace n ON e.extnamespace = n.oid +LEFT JOIN pg_description d ON d.objoid = e.oid AND d.classoid = 'pg_extension'::regclass +WHERE e.extname != 'plpgsql' +ORDER BY e.extname; \ No newline at end of file diff --git a/ir/queries/queries.sql.go b/ir/queries/queries.sql.go index cbca828f..16c2ba45 100644 --- a/ir/queries/queries.sql.go +++ b/ir/queries/queries.sql.go @@ -1396,6 +1396,57 @@ func (q *Queries) GetEnumValuesForSchema(ctx context.Context, dollar_1 sql.NullS return items, nil } +const getExtensions = `-- name: GetExtensions :many +SELECT + e.extname::text AS extension_name, + e.extversion::text AS extension_version, + n.nspname::text AS extension_schema, + COALESCE(d.description, '') AS extension_comment +FROM pg_extension e +JOIN pg_namespace n ON e.extnamespace = n.oid +LEFT JOIN pg_description d ON d.objoid = e.oid AND d.classoid = 'pg_extension'::regclass +WHERE e.extname != 'plpgsql' +ORDER BY e.extname +` + +type GetExtensionsRow struct { + ExtensionName sql.NullString `db:"extension_name" json:"extension_name"` + ExtensionVersion string `db:"extension_version" json:"extension_version"` + ExtensionSchema sql.NullString `db:"extension_schema" json:"extension_schema"` + ExtensionComment sql.NullString `db:"extension_comment" json:"extension_comment"` +} + +// GetExtensions retrieves all installed extensions except the always-present +// `plpgsql` built-in. Used to render `CREATE EXTENSION` statements in the dump +// so dumps remain replayable on a fresh database. 
+func (q *Queries) GetExtensions(ctx context.Context) ([]GetExtensionsRow, error) { + rows, err := q.db.QueryContext(ctx, getExtensions) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetExtensionsRow + for rows.Next() { + var i GetExtensionsRow + if err := rows.Scan( + &i.ExtensionName, + &i.ExtensionVersion, + &i.ExtensionSchema, + &i.ExtensionComment, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getFunctionDependencies = `-- name: GetFunctionDependencies :many SELECT dependent_ns.nspname AS dependent_schema, diff --git a/testdata/diff/create_extension/add_extension/diff.sql b/testdata/diff/create_extension/add_extension/diff.sql new file mode 100644 index 00000000..c30efd12 --- /dev/null +++ b/testdata/diff/create_extension/add_extension/diff.sql @@ -0,0 +1 @@ +CREATE EXTENSION IF NOT EXISTS btree_gist; diff --git a/testdata/diff/create_extension/add_extension/new.sql b/testdata/diff/create_extension/add_extension/new.sql new file mode 100644 index 00000000..c30efd12 --- /dev/null +++ b/testdata/diff/create_extension/add_extension/new.sql @@ -0,0 +1 @@ +CREATE EXTENSION IF NOT EXISTS btree_gist; diff --git a/testdata/diff/create_extension/add_extension/old.sql b/testdata/diff/create_extension/add_extension/old.sql new file mode 100644 index 00000000..4ba9e740 --- /dev/null +++ b/testdata/diff/create_extension/add_extension/old.sql @@ -0,0 +1 @@ +-- Empty schema (no extensions declared) diff --git a/testdata/diff/create_extension/add_extension/plan.json b/testdata/diff/create_extension/add_extension/plan.json new file mode 100644 index 00000000..05defe30 --- /dev/null +++ b/testdata/diff/create_extension/add_extension/plan.json @@ -0,0 +1,20 @@ +{ + "version": "1.0.0", + "pgschema_version": "1.9.0", + "created_at": "1970-01-01T00:00:00Z", + "source_fingerprint": 
{ + "hash": "965b1131737c955e24c7f827c55bd78e4cb49a75adfd04229e0ba297376f5085" + }, + "groups": [ + { + "steps": [ + { + "sql": "CREATE EXTENSION IF NOT EXISTS btree_gist;", + "type": "extension", + "operation": "create", + "path": "btree_gist" + } + ] + } + ] +} diff --git a/testdata/diff/create_extension/add_extension/plan.sql b/testdata/diff/create_extension/add_extension/plan.sql new file mode 100644 index 00000000..c30efd12 --- /dev/null +++ b/testdata/diff/create_extension/add_extension/plan.sql @@ -0,0 +1 @@ +CREATE EXTENSION IF NOT EXISTS btree_gist; diff --git a/testdata/diff/create_extension/add_extension/plan.txt b/testdata/diff/create_extension/add_extension/plan.txt new file mode 100644 index 00000000..f246e34c --- /dev/null +++ b/testdata/diff/create_extension/add_extension/plan.txt @@ -0,0 +1,8 @@ +Plan: 1 to add. + +Summary by type: + +DDL to be executed: +-------------------------------------------------- + +CREATE EXTENSION IF NOT EXISTS btree_gist; From 8c99b001118eb2ce3878e2d0ee9a63a22ae9a279 Mon Sep 17 00:00:00 2001 From: Guillaume Lecomte Date: Sat, 23 May 2026 23:32:05 +0200 Subject: [PATCH 2/2] fix: quote extension identifiers and broaden test cleanup Address Greptile review on #443: - CREATE/DROP EXTENSION now go through ir.QuoteIdentifier, so names like uuid-ossp render as a properly quoted identifier instead of invalid SQL. - The shared-cluster extension cleanup in runPlanAndApplyTest used to register only when setup.sql was non-empty. Extensions installed via new.sql (the create_extension fixture does exactly this) could leak across cases on the shared embedded postgres. Register the cleanup unconditionally. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- cmd/migrate_integration_test.go | 16 ++++++++-------- internal/diff/extension.go | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/cmd/migrate_integration_test.go b/cmd/migrate_integration_test.go index 869ee7ae..96edad79 100644 --- a/cmd/migrate_integration_test.go +++ b/cmd/migrate_integration_test.go @@ -228,6 +228,14 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *struct { t.Fatalf("Failed to create test database %s: %v", dbName, err) } + // Extensions are cluster-level and persist across tests on the shared + // embedded postgres instance. Register the teardown unconditionally — + // extensions can come from setup.sql, old.sql, or new.sql (the new.sql + // path matters for the create_extension fixture, which has no setup.sql). + t.Cleanup(func() { + cleanupSharedClusterExtensions(t) + }) + // STEP 0: Execute optional setup.sql (for cross-schema setup, extension types, etc.) if _, err := os.Stat(tc.setupFile); err == nil { setupContent, err := os.ReadFile(tc.setupFile) @@ -247,14 +255,6 @@ func runPlanAndApplyTest(t *testing.T, ctx context.Context, container *struct { if _, err := embeddedConn.ExecContext(ctx, string(setupContent)); err != nil { t.Fatalf("Failed to execute setup.sql to embedded postgres: %v", err) } - - // Extensions are cluster-level and persist across tests on the shared - // embedded postgres instance. Reset any non-default extensions when - // this test case finishes so the next case (or the next top-level test) - // inherits a clean cluster. 
- t.Cleanup(func() { - cleanupSharedClusterExtensions(t) - }) } } diff --git a/internal/diff/extension.go b/internal/diff/extension.go index b8e223df..cc365289 100644 --- a/internal/diff/extension.go +++ b/internal/diff/extension.go @@ -36,7 +36,7 @@ func generateDropExtensionsSQL(extensions []*ir.Extension, collector *diffCollec Source: ext, CanRunInTransaction: true, } - collector.collect(context, fmt.Sprintf("DROP EXTENSION IF EXISTS %s;", ext.Name)) + collector.collect(context, fmt.Sprintf("DROP EXTENSION IF EXISTS %s;", ir.QuoteIdentifier(ext.Name))) } } @@ -56,5 +56,5 @@ func extensionPath(ext *ir.Extension) string { // is given) or filtering out transient schemas. Preserving the user-declared // install schema is tracked as a follow-up to #436. func generateExtensionSQL(ext *ir.Extension) string { - return fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", ext.Name) + return fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s;", ir.QuoteIdentifier(ext.Name)) }