Skip to content

Commit 107457e

Browse files
committed
temp ci commit will fix
1 parent e5dfa2b commit 107457e

File tree

3 files changed

+244
-44
lines changed

3 files changed

+244
-44
lines changed

internal/compiler/parse_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,25 @@ const mysqlInListQuery = `/* name: FooByList :many */
5454
SELECT a, b FROM foo WHERE foo.a IN (?, ?);
5555
`
5656

57+
const starExpansionSeriesSchema = `
58+
CREATE TABLE alertreport (
59+
eventdate date
60+
);
61+
`
62+
63+
const starExpansionSeriesQuery = `-- name: CountAlertReportBy :many
64+
select DATE_TRUNC($1,ts)::text as datetime,coalesce(count,0) as count from
65+
(
66+
SELECT DATE_TRUNC($1,eventdate) as hr ,count(*)
67+
FROM alertreport
68+
where eventdate between $2 and $3
69+
GROUP BY 1
70+
) AS cnt
71+
right outer join ( SELECT * FROM generate_series ( $2, $3, CONCAT('1 ',$1)::interval) AS ts ) as dte
72+
on DATE_TRUNC($1, ts ) = cnt.hr
73+
order by 1 asc;
74+
`
75+
5776
type stubAnalyzer struct {
5877
analyze func(context.Context, ast.Node, string, []string, *named.ParamSet) (*analysispb.Analysis, error)
5978
}
@@ -126,6 +145,36 @@ func newMySQLInListCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
126145
}, stmts[0].Raw
127146
}
128147

148+
func newStarExpansionSeriesCompiler(t *testing.T) (*Compiler, *ast.RawStmt) {
149+
t.Helper()
150+
151+
parser := postgresql.NewParser()
152+
catalog := postgresql.NewCatalog()
153+
154+
schema, err := parser.Parse(strings.NewReader(starExpansionSeriesSchema))
155+
if err != nil {
156+
t.Fatal(err)
157+
}
158+
if err := catalog.Build(schema); err != nil {
159+
t.Fatal(err)
160+
}
161+
162+
stmts, err := parser.Parse(strings.NewReader(starExpansionSeriesQuery))
163+
if err != nil {
164+
t.Fatal(err)
165+
}
166+
if len(stmts) != 1 {
167+
t.Fatalf("expected 1 statement, got %d", len(stmts))
168+
}
169+
170+
return &Compiler{
171+
conf: config.SQL{Engine: config.EnginePostgreSQL},
172+
parser: parser,
173+
catalog: catalog,
174+
selector: newDefaultSelector(),
175+
}, stmts[0].Raw
176+
}
177+
129178
func assertBatchParameterNames(t *testing.T, params []Parameter) {
130179
t.Helper()
131180

@@ -168,6 +217,40 @@ func assertBatchParameterNames(t *testing.T, params []Parameter) {
168217
}
169218
}
170219

220+
func assertStarExpansionSeriesParameterNames(t *testing.T, params []Parameter) {
221+
t.Helper()
222+
223+
checks := []struct {
224+
idx int
225+
number int
226+
name string
227+
typ string
228+
}{
229+
{idx: 0, number: 1, name: "date_trunc", typ: "text"},
230+
{idx: 1, number: 2, name: "eventdate", typ: "date"},
231+
{idx: 2, number: 3, name: "eventdate", typ: "date"},
232+
}
233+
if len(params) != len(checks) {
234+
t.Fatalf("expected %d params, got %d", len(checks), len(params))
235+
}
236+
237+
for _, check := range checks {
238+
param := params[check.idx]
239+
if param.Number != check.number {
240+
t.Fatalf("param %d number mismatch: got %d want %d", check.idx, param.Number, check.number)
241+
}
242+
if param.Column == nil {
243+
t.Fatalf("param %d column is nil", check.idx)
244+
}
245+
if param.Column.Name != check.name {
246+
t.Fatalf("param %d name mismatch: got %q want %q", check.idx, param.Column.Name, check.name)
247+
}
248+
if param.Column.DataType != check.typ && param.Column.DataType != "pg_catalog."+check.typ {
249+
t.Fatalf("param %d type mismatch: got %q want %q or %q", check.idx, param.Column.DataType, check.typ, "pg_catalog."+check.typ)
250+
}
251+
}
252+
}
253+
171254
func TestInferQueryPreservesInsertSelectParamNamesWithCTEAndMixedParams(t *testing.T) {
172255
t.Parallel()
173256

@@ -247,3 +330,38 @@ func TestInferQueryPreservesDistinctMySQLInListParams(t *testing.T) {
247330
}
248331
}
249332
}
333+
334+
func TestInferQueryPreservesStarExpansionSeriesParamNames(t *testing.T) {
335+
t.Parallel()
336+
337+
comp, raw := newStarExpansionSeriesCompiler(t)
338+
anlys, err := comp.inferQuery(raw, starExpansionSeriesQuery)
339+
if err != nil {
340+
t.Fatal(err)
341+
}
342+
if anlys == nil {
343+
t.Fatal("expected non-nil analysis")
344+
}
345+
346+
assertStarExpansionSeriesParameterNames(t, anlys.Parameters)
347+
}
348+
349+
func TestParseQueryManagedDBPreservesStarExpansionSeriesParamNames(t *testing.T) {
350+
t.Parallel()
351+
352+
comp, raw := newStarExpansionSeriesCompiler(t)
353+
comp.analyzer = stubAnalyzer{analyze: func(_ context.Context, _ ast.Node, _ string, _ []string, _ *named.ParamSet) (*analysispb.Analysis, error) {
354+
return &analysispb.Analysis{Params: []*analysispb.Parameter{
355+
{Number: 1, Column: &analysispb.Column{DataType: "pg_catalog.text"}},
356+
{Number: 2, Column: &analysispb.Column{DataType: "pg_catalog.date"}},
357+
{Number: 3, Column: &analysispb.Column{DataType: "pg_catalog.date"}},
358+
}}, nil
359+
}}
360+
361+
query, err := comp.parseQuery(raw, starExpansionSeriesQuery, opts.Parser{})
362+
if err != nil {
363+
t.Fatal(err)
364+
}
365+
366+
assertStarExpansionSeriesParameterNames(t, query.Params)
367+
}

internal/compiler/resolve.go

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,72 @@ func (comp *Compiler) incompatibleParamRefError(ref paramRef, existing, incoming
104104
}
105105
}
106106

107+
func sameTypeName(a, b *ast.TypeName) bool {
108+
if a == nil || b == nil {
109+
return a == nil && b == nil
110+
}
111+
return a.Catalog == b.Catalog &&
112+
a.Schema == b.Schema &&
113+
a.Name == b.Name &&
114+
arrayDims(a) == arrayDims(b)
115+
}
116+
117+
func funcCallArg(fn catalog.Function, idx int, namedArg string) *catalog.Argument {
118+
args := fn.InArgs()
119+
if namedArg != "" {
120+
for _, arg := range args {
121+
if arg.Name == namedArg {
122+
return arg
123+
}
124+
}
125+
return nil
126+
}
127+
if idx < 0 || idx >= len(args) {
128+
return nil
129+
}
130+
return args[idx]
131+
}
132+
133+
func funcCallArgMetadata(funcs []catalog.Function, idx int, namedArg string) (string, *ast.TypeName) {
134+
var (
135+
found bool
136+
name string
137+
nameConsistent = true
138+
typ *ast.TypeName
139+
typeConsistent = true
140+
)
141+
142+
for _, fn := range funcs {
143+
arg := funcCallArg(fn, idx, namedArg)
144+
if arg == nil {
145+
continue
146+
}
147+
if !found {
148+
found = true
149+
name = arg.Name
150+
typ = arg.Type
151+
continue
152+
}
153+
if name != arg.Name {
154+
nameConsistent = false
155+
}
156+
if !sameTypeName(typ, arg.Type) {
157+
typeConsistent = false
158+
}
159+
}
160+
161+
if !found {
162+
return "", nil
163+
}
164+
if !nameConsistent {
165+
name = ""
166+
}
167+
if !typeConsistent {
168+
typ = nil
169+
}
170+
return name, typ
171+
}
172+
107173
func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar, args []paramRef, params *named.ParamSet, embeds rewrite.EmbedSet) ([]Parameter, error) {
108174
c := comp.catalog
109175

@@ -519,7 +585,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
519585
}
520586

521587
case *ast.FuncCall:
522-
fun, err := c.ResolveFuncCall(n)
588+
funcs, err := c.ResolveFuncCalls(n)
589+
var fun *catalog.Function
523590
if err != nil {
524591
// Synthesize a function on the fly to avoid returning with an error
525592
// for an unknown Postgres function (e.g. defined in an extension)
@@ -534,6 +601,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
534601
Args: args,
535602
ReturnType: &ast.TypeName{Name: "any"},
536603
}
604+
funcs = []catalog.Function{*fun}
605+
} else {
606+
fun = &funcs[0]
537607
}
538608

539609
var added bool
@@ -592,24 +662,9 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
592662
continue
593663
}
594664

595-
var paramName string
596-
var paramType *ast.TypeName
597-
598-
if argName == "" {
599-
if i < len(fun.Args) {
600-
paramName = fun.Args[i].Name
601-
paramType = fun.Args[i].Type
602-
}
603-
} else {
665+
paramName, paramType := funcCallArgMetadata(funcs, i, argName)
666+
if argName != "" {
604667
paramName = argName
605-
for _, arg := range fun.Args {
606-
if arg.Name == argName {
607-
paramType = arg.Type
608-
}
609-
}
610-
if paramType == nil {
611-
panic(fmt.Sprintf("named argument %s has no type", paramName))
612-
}
613668
}
614669
if paramName == "" {
615670
paramName = funcName

internal/sql/catalog/public.go

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,37 +32,49 @@ func (c *Catalog) ListFuncsByName(rel *ast.FuncName) ([]Function, error) {
3232
return funcs, nil
3333
}
3434

35-
func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) {
35+
func splitFuncCallArgs(call *ast.FuncCall) ([]ast.Node, []*ast.NamedArgExpr, error) {
36+
var positional []ast.Node
37+
var named []*ast.NamedArgExpr
38+
39+
if call.Args == nil {
40+
return positional, named, nil
41+
}
42+
43+
for _, arg := range call.Args.Items {
44+
if narg, ok := arg.(*ast.NamedArgExpr); ok {
45+
named = append(named, narg)
46+
continue
47+
}
48+
49+
// The mixed notation combines positional and named notation.
50+
// However, as already mentioned, named arguments cannot precede
51+
// positional arguments.
52+
if len(named) > 0 {
53+
return nil, nil, &sqlerr.Error{
54+
Code: "",
55+
Message: "positional argument cannot follow named argument",
56+
Location: call.Pos(),
57+
}
58+
}
59+
positional = append(positional, arg)
60+
}
61+
62+
return positional, named, nil
63+
}
64+
65+
func (c *Catalog) ResolveFuncCalls(call *ast.FuncCall) ([]Function, error) {
3666
// Do not validate unknown functions
3767
funs, err := c.ListFuncsByName(call.Func)
3868
if err != nil || len(funs) == 0 {
3969
return nil, sqlerr.FunctionNotFound(call.Func.Name)
4070
}
4171

42-
// https://www.postgresql.org/docs/current/sql-syntax-calling-funcs.html
43-
var positional []ast.Node
44-
var named []*ast.NamedArgExpr
45-
46-
if call.Args != nil {
47-
for _, arg := range call.Args.Items {
48-
if narg, ok := arg.(*ast.NamedArgExpr); ok {
49-
named = append(named, narg)
50-
} else {
51-
// The mixed notation combines positional and named notation.
52-
// However, as already mentioned, named arguments cannot precede
53-
// positional arguments.
54-
if len(named) > 0 {
55-
return nil, &sqlerr.Error{
56-
Code: "",
57-
Message: "positional argument cannot follow named argument",
58-
Location: call.Pos(),
59-
}
60-
}
61-
positional = append(positional, arg)
62-
}
63-
}
72+
positional, named, err := splitFuncCallArgs(call)
73+
if err != nil {
74+
return nil, err
6475
}
6576

77+
var matches []Function
6678
for _, fun := range funs {
6779
args := fun.InArgs()
6880
var defaults int
@@ -94,7 +106,6 @@ func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) {
94106
}
95107
}
96108

97-
// Validate that the provided named arguments exist in the function
98109
var unknownArgName bool
99110
for _, expr := range named {
100111
if expr.Name != nil {
@@ -107,11 +118,19 @@ func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) {
107118
continue
108119
}
109120

110-
return &fun, nil
121+
matches = append(matches, fun)
122+
}
123+
124+
if len(matches) > 0 {
125+
return matches, nil
111126
}
112127

113128
var sig []string
114-
for range call.Args.Items {
129+
argCount := 0
130+
if call.Args != nil {
131+
argCount = len(call.Args.Items)
132+
}
133+
for range argCount {
115134
sig = append(sig, "unknown")
116135
}
117136

@@ -123,6 +142,14 @@ func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) {
123142
}
124143
}
125144

145+
func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) {
146+
matches, err := c.ResolveFuncCalls(call)
147+
if err != nil {
148+
return nil, err
149+
}
150+
return &matches[0], nil
151+
}
152+
126153
func (c *Catalog) GetTable(rel *ast.TableName) (Table, error) {
127154
_, table, err := c.getTable(rel)
128155
if table == nil {

0 commit comments

Comments
 (0)