Skip to content

Commit ab9ce9e

Browse files
kyleconroyclaude
andcommitted
refactor(fmt_test): use config-based engine detection and parser for statement boundaries
- Parse sqlc config file to determine database engine instead of hardcoding pgx/v5 path filter - Use parser's StmtLocation/StmtLen for proper statement boundaries instead of naive semicolon splitting - Handle both file and directory paths in queries config - Only test PostgreSQL for now (formatting support is PostgreSQL-only) This fixes issues with multi-query files containing semicolons in strings, PL/pgSQL functions, or DO blocks. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 3436b28 commit ab9ce9e

File tree

1 file changed

+97
-43
lines changed

1 file changed

+97
-43
lines changed

internal/endtoend/fmt_test.go

Lines changed: 97 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,66 +8,120 @@ import (
88
"strings"
99
"testing"
1010

11+
"github.com/sqlc-dev/sqlc/internal/config"
1112
"github.com/sqlc-dev/sqlc/internal/debug"
1213
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
1314
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1415
)
1516

1617
func TestFormat(t *testing.T) {
1718
t.Parallel()
18-
parse := postgresql.NewParser()
1919
for _, tc := range FindTests(t, "testdata", "base") {
2020
tc := tc
21-
22-
if !strings.Contains(tc.Path, filepath.Join("pgx/v5")) {
23-
continue
24-
}
25-
26-
q := filepath.Join(tc.Path, "query.sql")
27-
if _, err := os.Stat(q); os.IsNotExist(err) {
28-
continue
29-
}
30-
3121
t.Run(tc.Name, func(t *testing.T) {
32-
contents, err := os.ReadFile(q)
22+
// Parse the config file to determine the engine
23+
configPath := filepath.Join(tc.Path, tc.ConfigName)
24+
configFile, err := os.Open(configPath)
3325
if err != nil {
3426
t.Fatal(err)
3527
}
36-
for i, query := range bytes.Split(bytes.TrimSpace(contents), []byte(";")) {
37-
if len(query) <= 1 {
38-
continue
39-
}
40-
query := query
41-
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
42-
expected, err := postgresql.Fingerprint(string(query))
43-
if err != nil {
44-
t.Fatal(err)
45-
}
46-
stmts, err := parse.Parse(bytes.NewReader(query))
47-
if err != nil {
48-
t.Fatal(err)
49-
}
50-
if len(stmts) != 1 {
51-
t.Fatal("expected one statement")
52-
}
53-
if false {
54-
r, err := postgresql.Parse(string(query))
55-
debug.Dump(r, err)
56-
}
28+
conf, err := config.ParseConfig(configFile)
29+
configFile.Close()
30+
if err != nil {
31+
t.Fatal(err)
32+
}
33+
34+
// Skip if there are no SQL packages configured
35+
if len(conf.SQL) == 0 {
36+
return
37+
}
38+
39+
// For now, only test PostgreSQL since that's the only engine with Format support
40+
engine := conf.SQL[0].Engine
41+
if engine != config.EnginePostgreSQL {
42+
return
43+
}
5744

58-
out := ast.Format(stmts[0].Raw)
59-
actual, err := postgresql.Fingerprint(out)
45+
// Find query files from config
46+
var queryFiles []string
47+
for _, sql := range conf.SQL {
48+
for _, q := range sql.Queries {
49+
queryPath := filepath.Join(tc.Path, q)
50+
info, err := os.Stat(queryPath)
6051
if err != nil {
61-
t.Error(err)
52+
continue
6253
}
63-
if expected != actual {
64-
debug.Dump(stmts[0].Raw)
65-
t.Errorf("- %s", expected)
66-
t.Errorf("- %s", string(query))
67-
t.Errorf("+ %s", actual)
68-
t.Errorf("+ %s", out)
54+
if info.IsDir() {
55+
// If it's a directory, glob for .sql files
56+
matches, err := filepath.Glob(filepath.Join(queryPath, "*.sql"))
57+
if err != nil {
58+
continue
59+
}
60+
queryFiles = append(queryFiles, matches...)
61+
} else {
62+
queryFiles = append(queryFiles, queryPath)
6963
}
70-
})
64+
}
65+
}
66+
67+
if len(queryFiles) == 0 {
68+
return
69+
}
70+
71+
parse := postgresql.NewParser()
72+
73+
for _, queryFile := range queryFiles {
74+
if _, err := os.Stat(queryFile); os.IsNotExist(err) {
75+
continue
76+
}
77+
78+
contents, err := os.ReadFile(queryFile)
79+
if err != nil {
80+
t.Fatal(err)
81+
}
82+
83+
// Parse the entire file to get proper statement boundaries
84+
stmts, err := parse.Parse(bytes.NewReader(contents))
85+
if err != nil {
86+
t.Fatal(err)
87+
}
88+
89+
for i, stmt := range stmts {
90+
stmt := stmt
91+
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
92+
// Extract the original query text using statement location and length
93+
start := stmt.Raw.StmtLocation
94+
length := stmt.Raw.StmtLen
95+
if length == 0 {
96+
// If StmtLen is 0, it means the statement goes to the end of the input
97+
length = len(contents) - start
98+
}
99+
query := strings.TrimSpace(string(contents[start : start+length]))
100+
101+
expected, err := postgresql.Fingerprint(query)
102+
if err != nil {
103+
t.Fatal(err)
104+
}
105+
106+
if false {
107+
r, err := postgresql.Parse(query)
108+
debug.Dump(r, err)
109+
}
110+
111+
out := ast.Format(stmt.Raw)
112+
actual, err := postgresql.Fingerprint(out)
113+
if err != nil {
114+
t.Error(err)
115+
}
116+
if expected != actual {
117+
debug.Dump(stmt.Raw)
118+
t.Errorf("- %s", expected)
119+
t.Errorf("- %s", query)
120+
t.Errorf("+ %s", actual)
121+
t.Errorf("+ %s", out)
122+
}
123+
})
124+
}
71125
}
72126
})
73127
}

0 commit comments

Comments
 (0)