@@ -8,66 +8,120 @@ import (
88 "strings"
99 "testing"
1010
11+ "github.com/sqlc-dev/sqlc/internal/config"
1112 "github.com/sqlc-dev/sqlc/internal/debug"
1213 "github.com/sqlc-dev/sqlc/internal/engine/postgresql"
1314 "github.com/sqlc-dev/sqlc/internal/sql/ast"
1415)
1516
1617func TestFormat (t * testing.T ) {
1718 t .Parallel ()
18- parse := postgresql .NewParser ()
1919 for _ , tc := range FindTests (t , "testdata" , "base" ) {
2020 tc := tc
21-
22- if ! strings .Contains (tc .Path , filepath .Join ("pgx/v5" )) {
23- continue
24- }
25-
26- q := filepath .Join (tc .Path , "query.sql" )
27- if _ , err := os .Stat (q ); os .IsNotExist (err ) {
28- continue
29- }
30-
3121 t .Run (tc .Name , func (t * testing.T ) {
32- contents , err := os .ReadFile (q )
22+ // Parse the config file to determine the engine
23+ configPath := filepath .Join (tc .Path , tc .ConfigName )
24+ configFile , err := os .Open (configPath )
3325 if err != nil {
3426 t .Fatal (err )
3527 }
36- for i , query := range bytes .Split (bytes .TrimSpace (contents ), []byte (";" )) {
37- if len (query ) <= 1 {
38- continue
39- }
40- query := query
41- t .Run (fmt .Sprintf ("%d" , i ), func (t * testing.T ) {
42- expected , err := postgresql .Fingerprint (string (query ))
43- if err != nil {
44- t .Fatal (err )
45- }
46- stmts , err := parse .Parse (bytes .NewReader (query ))
47- if err != nil {
48- t .Fatal (err )
49- }
50- if len (stmts ) != 1 {
51- t .Fatal ("expected one statement" )
52- }
53- if false {
54- r , err := postgresql .Parse (string (query ))
55- debug .Dump (r , err )
56- }
28+ conf , err := config .ParseConfig (configFile )
29+ configFile .Close ()
30+ if err != nil {
31+ t .Fatal (err )
32+ }
33+
34+ // Skip if there are no SQL packages configured
35+ if len (conf .SQL ) == 0 {
36+ return
37+ }
38+
39+ // For now, only test PostgreSQL since that's the only engine with Format support
40+ engine := conf .SQL [0 ].Engine
41+ if engine != config .EnginePostgreSQL {
42+ return
43+ }
5744
58- out := ast .Format (stmts [0 ].Raw )
59- actual , err := postgresql .Fingerprint (out )
45+ // Find query files from config
46+ var queryFiles []string
47+ for _ , sql := range conf .SQL {
48+ for _ , q := range sql .Queries {
49+ queryPath := filepath .Join (tc .Path , q )
50+ info , err := os .Stat (queryPath )
6051 if err != nil {
61- t . Error ( err )
52+ continue
6253 }
63- if expected != actual {
64- debug .Dump (stmts [0 ].Raw )
65- t .Errorf ("- %s" , expected )
66- t .Errorf ("- %s" , string (query ))
67- t .Errorf ("+ %s" , actual )
68- t .Errorf ("+ %s" , out )
54+ if info .IsDir () {
55+ // If it's a directory, glob for .sql files
56+ matches , err := filepath .Glob (filepath .Join (queryPath , "*.sql" ))
57+ if err != nil {
58+ continue
59+ }
60+ queryFiles = append (queryFiles , matches ... )
61+ } else {
62+ queryFiles = append (queryFiles , queryPath )
6963 }
70- })
64+ }
65+ }
66+
67+ if len (queryFiles ) == 0 {
68+ return
69+ }
70+
71+ parse := postgresql .NewParser ()
72+
73+ for _ , queryFile := range queryFiles {
74+ if _ , err := os .Stat (queryFile ); os .IsNotExist (err ) {
75+ continue
76+ }
77+
78+ contents , err := os .ReadFile (queryFile )
79+ if err != nil {
80+ t .Fatal (err )
81+ }
82+
83+ // Parse the entire file to get proper statement boundaries
84+ stmts , err := parse .Parse (bytes .NewReader (contents ))
85+ if err != nil {
86+ t .Fatal (err )
87+ }
88+
89+ for i , stmt := range stmts {
90+ stmt := stmt
91+ t .Run (fmt .Sprintf ("%d" , i ), func (t * testing.T ) {
92+ // Extract the original query text using statement location and length
93+ start := stmt .Raw .StmtLocation
94+ length := stmt .Raw .StmtLen
95+ if length == 0 {
96+ // If StmtLen is 0, it means the statement goes to the end of the input
97+ length = len (contents ) - start
98+ }
99+ query := strings .TrimSpace (string (contents [start : start + length ]))
100+
101+ expected , err := postgresql .Fingerprint (query )
102+ if err != nil {
103+ t .Fatal (err )
104+ }
105+
106+ if false {
107+ r , err := postgresql .Parse (query )
108+ debug .Dump (r , err )
109+ }
110+
111+ out := ast .Format (stmt .Raw )
112+ actual , err := postgresql .Fingerprint (out )
113+ if err != nil {
114+ t .Error (err )
115+ }
116+ if expected != actual {
117+ debug .Dump (stmt .Raw )
118+ t .Errorf ("- %s" , expected )
119+ t .Errorf ("- %s" , query )
120+ t .Errorf ("+ %s" , actual )
121+ t .Errorf ("+ %s" , out )
122+ }
123+ })
124+ }
71125 }
72126 })
73127 }
0 commit comments