Skip to content

Commit e5dfa2b

Browse files
committed
fix: inferred column names
1 parent f758c7e commit e5dfa2b

File tree

3 files changed

+364
-60
lines changed

3 files changed

+364
-60
lines changed

internal/compiler/find_params.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
140140
p.parent = node
141141

142142
case *ast.SelectStmt:
143+
if n.FromClause != nil && len(n.FromClause.Items) == 1 {
144+
if rv, ok := n.FromClause.Items[0].(*ast.RangeVar); ok {
145+
p.rangeVar = rv
146+
}
147+
}
143148
if n.LimitCount != nil {
144149
p.limitCount = n.LimitCount
145150
}

internal/compiler/parse_test.go

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
package compiler
2+
3+
import (
4+
"context"
5+
"strings"
6+
"testing"
7+
8+
analysispb "github.com/sqlc-dev/sqlc/internal/analysis"
9+
"github.com/sqlc-dev/sqlc/internal/config"
10+
"github.com/sqlc-dev/sqlc/internal/engine/dolphin"
11+
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
12+
"github.com/sqlc-dev/sqlc/internal/opts"
13+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
14+
"github.com/sqlc-dev/sqlc/internal/sql/named"
15+
)
16+
17+
const batchParameterTypeSchema = `
18+
CREATE TABLE public.solar_commcard_mapping (
19+
id INT8 NOT NULL,
20+
"deviceId" INT8 NOT NULL,
21+
version VARCHAR(32) DEFAULT ''::VARCHAR NOT NULL,
22+
sn VARCHAR(32) DEFAULT ''::VARCHAR NOT NULL,
23+
"createdAt" TIMESTAMPTZ DEFAULT now(),
24+
"updatedAt" TIMESTAMPTZ DEFAULT now()
25+
);
26+
`
27+
28+
const batchParameterTypeQuery = `-- name: InsertMappping :batchexec
29+
WITH
30+
table1 AS (
31+
SELECT version
32+
FROM solar_commcard_mapping
33+
WHERE "deviceId" = $1
34+
ORDER BY "updatedAt" DESC
35+
LIMIT 1
36+
)
37+
INSERT INTO solar_commcard_mapping ("deviceId", version, sn, "updatedAt")
38+
SELECT $1, @version::text, $3, $4
39+
WHERE NOT EXISTS (
40+
SELECT *
41+
FROM table1
42+
WHERE table1.version = @version::text
43+
) OR NOT EXISTS (SELECT * FROM table1);
44+
`
45+
46+
const mysqlInListSchema = `
47+
CREATE TABLE foo (
48+
a VARCHAR(255) NOT NULL,
49+
b VARCHAR(255) NOT NULL
50+
);
51+
`
52+
53+
const mysqlInListQuery = `/* name: FooByList :many */
54+
SELECT a, b FROM foo WHERE foo.a IN (?, ?);
55+
`
56+
57+
type stubAnalyzer struct {
58+
analyze func(context.Context, ast.Node, string, []string, *named.ParamSet) (*analysispb.Analysis, error)
59+
}
60+
61+
func (s stubAnalyzer) Analyze(ctx context.Context, n ast.Node, q string, schema []string, np *named.ParamSet) (*analysispb.Analysis, error) {
62+
return s.analyze(ctx, n, q, schema, np)
63+
}
64+
65+
func (stubAnalyzer) Close(context.Context) error { return nil }
66+
func (stubAnalyzer) EnsureConn(context.Context, []string) error { return nil }
67+
func (stubAnalyzer) GetColumnNames(context.Context, string) ([]string, error) { return nil, nil }
68+
69+
func newBatchParameterTypeCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
70+
t.Helper()
71+
72+
parser := postgresql.NewParser()
73+
catalog := postgresql.NewCatalog()
74+
75+
schema, err := parser.Parse(strings.NewReader(batchParameterTypeSchema))
76+
if err != nil {
77+
t.Fatal(err)
78+
}
79+
if err := catalog.Build(schema); err != nil {
80+
t.Fatal(err)
81+
}
82+
83+
stmts, err := parser.Parse(strings.NewReader(batchParameterTypeQuery))
84+
if err != nil {
85+
t.Fatal(err)
86+
}
87+
if len(stmts) != 1 {
88+
t.Fatalf("expected 1 statement, got %d", len(stmts))
89+
}
90+
91+
return &Compiler{
92+
conf: config.SQL{Engine: config.EnginePostgreSQL},
93+
parser: parser,
94+
catalog: catalog,
95+
selector: newDefaultSelector(),
96+
}, stmts[0].Raw
97+
}
98+
99+
func newMySQLInListCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
100+
t.Helper()
101+
102+
parser := dolphin.NewParser()
103+
catalog := dolphin.NewCatalog()
104+
105+
schema, err := parser.Parse(strings.NewReader(mysqlInListSchema))
106+
if err != nil {
107+
t.Fatal(err)
108+
}
109+
if err := catalog.Build(schema); err != nil {
110+
t.Fatal(err)
111+
}
112+
113+
stmts, err := parser.Parse(strings.NewReader(mysqlInListQuery))
114+
if err != nil {
115+
t.Fatal(err)
116+
}
117+
if len(stmts) != 1 {
118+
t.Fatalf("expected 1 statement, got %d", len(stmts))
119+
}
120+
121+
return &Compiler{
122+
conf: config.SQL{Engine: config.EngineMySQL},
123+
parser: parser,
124+
catalog: catalog,
125+
selector: newDefaultSelector(),
126+
}, stmts[0].Raw
127+
}
128+
129+
func assertBatchParameterNames(t *testing.T, params []Parameter) {
130+
t.Helper()
131+
132+
checks := []struct {
133+
idx int
134+
number int
135+
name string
136+
original string
137+
named bool
138+
}{
139+
{idx: 0, number: 1, name: "deviceId", original: "deviceId"},
140+
{idx: 1, number: 2, name: "version", original: "version", named: true},
141+
{idx: 2, number: 3, name: "sn", original: "sn"},
142+
{idx: 3, number: 4, name: "updatedAt", original: "updatedAt"},
143+
}
144+
if len(params) != len(checks) {
145+
t.Fatalf("expected %d params, got %d", len(checks), len(params))
146+
}
147+
148+
for _, check := range checks {
149+
param := params[check.idx]
150+
if param.Number != check.number {
151+
t.Fatalf("param %d number mismatch: got %d want %d", check.idx, param.Number, check.number)
152+
}
153+
if param.Column == nil {
154+
t.Fatalf("param %d column is nil", check.idx)
155+
}
156+
if param.Column.Name != check.name {
157+
t.Fatalf("param %d name mismatch: got %q want %q", check.idx, param.Column.Name, check.name)
158+
}
159+
if param.Column.OriginalName != check.original {
160+
t.Fatalf("param %d original name mismatch: got %q want %q", check.idx, param.Column.OriginalName, check.original)
161+
}
162+
if param.Column.IsNamedParam != check.named {
163+
t.Fatalf("param %d named mismatch: got %v want %v", check.idx, param.Column.IsNamedParam, check.named)
164+
}
165+
if param.Column.DataType == "" || param.Column.DataType == "any" {
166+
t.Fatalf("param %d type was not inferred: %+v", check.idx, param.Column)
167+
}
168+
}
169+
}
170+
171+
func TestInferQueryPreservesInsertSelectParamNamesWithCTEAndMixedParams(t *testing.T) {
172+
t.Parallel()
173+
174+
comp, raw := newBatchParameterTypeCompiler(t)
175+
anlys, err := comp.inferQuery(raw, batchParameterTypeQuery)
176+
if err != nil && !strings.Contains(err.Error(), "parameter $2") {
177+
t.Fatalf("unexpected infer error: %v", err)
178+
}
179+
if anlys == nil {
180+
t.Fatal("expected non-nil analysis")
181+
}
182+
if !strings.Contains(anlys.Query, "$2::text") {
183+
t.Fatalf("expected rewritten query to contain $2::text, got %q", anlys.Query)
184+
}
185+
186+
assertBatchParameterNames(t, anlys.Parameters)
187+
}
188+
189+
func TestParseQueryManagedDBPreservesInferredParamNames(t *testing.T) {
190+
t.Parallel()
191+
192+
comp, raw := newBatchParameterTypeCompiler(t)
193+
comp.analyzer = stubAnalyzer{analyze: func(_ context.Context, _ ast.Node, query string, _ []string, np *named.ParamSet) (*analysispb.Analysis, error) {
194+
if np == nil {
195+
t.Fatal("expected named param set")
196+
}
197+
if got, ok := np.NameFor(2); !ok || got != "version" {
198+
t.Fatalf("expected param 2 to be named version, got %q %v", got, ok)
199+
}
200+
if !strings.Contains(query, "$2::text") {
201+
t.Fatalf("expected analyzer query to contain rewritten named param, got %q", query)
202+
}
203+
return &analysispb.Analysis{Params: []*analysispb.Parameter{
204+
{Number: 1, Column: &analysispb.Column{DataType: "pg_catalog.int8"}},
205+
{Number: 2, Column: &analysispb.Column{Name: "version", DataType: "text", IsNamedParam: true}},
206+
{Number: 3, Column: &analysispb.Column{DataType: "text"}},
207+
{Number: 4, Column: &analysispb.Column{DataType: "pg_catalog.timestamptz"}},
208+
}}, nil
209+
}}
210+
211+
query, err := comp.parseQuery(raw, batchParameterTypeQuery, opts.Parser{})
212+
if err != nil {
213+
t.Fatal(err)
214+
}
215+
216+
assertBatchParameterNames(t, query.Params)
217+
}
218+
219+
func TestInferQueryPreservesDistinctMySQLInListParams(t *testing.T) {
220+
t.Parallel()
221+
222+
comp, raw := newMySQLInListCompiler(t)
223+
anlys, err := comp.inferQuery(raw, mysqlInListQuery)
224+
if err != nil {
225+
t.Fatal(err)
226+
}
227+
if anlys == nil {
228+
t.Fatal("expected non-nil analysis")
229+
}
230+
if len(anlys.Parameters) != 2 {
231+
t.Fatalf("expected 2 params, got %d", len(anlys.Parameters))
232+
}
233+
234+
for i, wantNumber := range []int{1, 2} {
235+
param := anlys.Parameters[i]
236+
if param.Number != wantNumber {
237+
t.Fatalf("param %d number mismatch: got %d want %d", i, param.Number, wantNumber)
238+
}
239+
if param.Column == nil {
240+
t.Fatalf("param %d column is nil", i)
241+
}
242+
if param.Column.OriginalName != "a" {
243+
t.Fatalf("param %d original name mismatch: got %q want %q", i, param.Column.OriginalName, "a")
244+
}
245+
if param.Column.DataType == "" || param.Column.DataType == "any" {
246+
t.Fatalf("param %d type was not inferred: %+v", i, param.Column)
247+
}
248+
}
249+
}

0 commit comments

Comments
 (0)