diff --git a/ir/queries/queries.sql b/ir/queries/queries.sql index 062bc135..28dbac9d 100644 --- a/ir/queries/queries.sql +++ b/ir/queries/queries.sql @@ -972,12 +972,18 @@ LEFT JOIN pg_class dep_table ON d.refobjid = dep_table.oid LEFT JOIN pg_attribute dep_col ON dep_col.attrelid = dep_table.oid AND dep_col.attnum = d.refobjsubid -- Method 2: Find sequences used in column defaults (for nextval() patterns) LEFT JOIN ( - SELECT + SELECT col.table_name, col.column_name, - REGEXP_REPLACE( - REGEXP_REPLACE(col.column_default, 'nextval\(''([^'']+)''.*\)', '\1'), - '^[^.]*\.', '' + REPLACE( + REGEXP_REPLACE( + REGEXP_REPLACE( + REGEXP_REPLACE(col.column_default, 'nextval\(''([^'']+)''.*\)', '\1'), + '^("([^"]|"")*"\.|[^.]*\.)', '' + ), + '^"(.*)"$', '\1' + ), + '""', '"' ) AS sequence_name FROM information_schema.columns col WHERE col.table_schema = $1 diff --git a/ir/queries/queries.sql.go b/ir/queries/queries.sql.go index 0a5e4848..47015ead 100644 --- a/ir/queries/queries.sql.go +++ b/ir/queries/queries.sql.go @@ -2732,12 +2732,18 @@ LEFT JOIN pg_depend d ON d.objid = c.oid AND d.classid = 'pg_class'::regclass AN LEFT JOIN pg_class dep_table ON d.refobjid = dep_table.oid LEFT JOIN pg_attribute dep_col ON dep_col.attrelid = dep_table.oid AND dep_col.attnum = d.refobjsubid LEFT JOIN ( - SELECT + SELECT col.table_name, col.column_name, - REGEXP_REPLACE( - REGEXP_REPLACE(col.column_default, 'nextval\(''([^'']+)''.*\)', '\1'), - '^[^.]*\.', '' + REPLACE( + REGEXP_REPLACE( + REGEXP_REPLACE( + REGEXP_REPLACE(col.column_default, 'nextval\(''([^'']+)''.*\)', '\1'), + '^("([^"]|"")*"\.|[^.]*\.)', '' + ), + '^"(.*)"$', '\1' + ), + '""', '"' ) AS sequence_name FROM information_schema.columns col WHERE col.table_schema = $1 diff --git a/ir/queries/queries_test.go b/ir/queries/queries_test.go new file mode 100644 index 00000000..4264c38f --- /dev/null +++ b/ir/queries/queries_test.go @@ -0,0 +1,71 @@ +package queries_test + +import ( + "context" + "database/sql" + "strings" + "testing" + + "github.com/pgplex/pgschema/internal/postgres" + "github.com/pgplex/pgschema/ir/queries" + "github.com/pgplex/pgschema/testutil" +) + +var sharedTestPostgres *postgres.EmbeddedPostgres + +func TestMain(m *testing.M) { + sharedTestPostgres = testutil.SetupPostgres(nil) + defer sharedTestPostgres.Stop() + m.Run() +} + +func TestGetSequencesForSchemaDetectsMixedCaseSequenceInColumnDefault(t *testing.T) { + conn, _, _, _, _, _ := testutil.ConnectToPostgres(t, sharedTestPostgres) + defer conn.Close() + + ctx := context.Background() + if _, err := conn.ExecContext(ctx, `DROP TABLE IF EXISTS orders CASCADE`); err != nil { + t.Fatalf("failed to drop test table: %v", err) + } + if _, err := conn.ExecContext(ctx, `CREATE TABLE orders ("orderId" SERIAL PRIMARY KEY)`); err != nil { + t.Fatalf("failed to create test table: %v", err) + } + // Drop the pg_depend ownership edge that SERIAL creates automatically. + // GetSequencesForSchema detects ownership via two paths: pg_depend (primary) + // and column_default parsing (fallback). Without this, pg_depend resolves + // ownership before the column_default regex is ever reached, so the test + // would pass even with the broken regex. OWNED BY NONE forces the fallback + // path — the one that was broken for mixed-case identifiers before this fix. + if _, err := conn.ExecContext(ctx, `ALTER SEQUENCE "orders_orderId_seq" OWNED BY NONE`); err != nil { + t.Fatalf("failed to remove sequence ownership dependency: %v", err) + } + + rows, err := queries.New(conn).GetSequencesForSchema(ctx, sql.NullString{String: "public", Valid: true}) + if err != nil { + t.Fatalf("failed to get sequences for schema: %v", err) + } + + for _, row := range rows { + if row.SequenceName.String != "orders_orderId_seq" { + continue + } + + if !row.OwnedByTable.Valid || row.OwnedByTable.String != "orders" { + t.Fatalf("OwnedByTable = %q, want %q", row.OwnedByTable.String, "orders") + } + if !row.OwnedByColumn.Valid || row.OwnedByColumn.String != "orderId" { + t.Fatalf("OwnedByColumn = %q, want %q", row.OwnedByColumn.String, "orderId") + } + return + } + + t.Fatalf("sequence %q not found; got sequences: %s", "orders_orderId_seq", sequenceNames(rows)) +} + +func sequenceNames(rows []queries.GetSequencesForSchemaRow) string { + names := make([]string, 0, len(rows)) + for _, row := range rows { + names = append(names, row.SequenceName.String) + } + return strings.Join(names, ", ") +}