Skip to content

Commit f6af815

Browse files
committed
fix: some type bugs
1 parent f758c7e commit f6af815

File tree

2 files changed

+182
-13
lines changed

2 files changed

+182
-13
lines changed

internal/compiler/resolve.go

Lines changed: 146 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,137 @@ func compatibleParamTypes(a, b *Column) bool {
5353
a.ArrayDims == b.ArrayDims
5454
}
5555

56+
func sameTypeName(a, b *ast.TypeName) bool {
57+
if a == nil || b == nil {
58+
return a == nil && b == nil
59+
}
60+
return a.Catalog == b.Catalog && a.Schema == b.Schema && a.Name == b.Name
61+
}
62+
63+
func matchingFuncCallOverloads(c *catalog.Catalog, call *ast.FuncCall) []catalog.Function {
64+
funs, err := c.ListFuncsByName(call.Func)
65+
if err != nil {
66+
return nil
67+
}
68+
69+
var positional []ast.Node
70+
var named []*ast.NamedArgExpr
71+
if call.Args != nil {
72+
for _, arg := range call.Args.Items {
73+
if narg, ok := arg.(*ast.NamedArgExpr); ok {
74+
named = append(named, narg)
75+
continue
76+
}
77+
if len(named) > 0 {
78+
return nil
79+
}
80+
positional = append(positional, arg)
81+
}
82+
}
83+
84+
var matches []catalog.Function
85+
for _, fun := range funs {
86+
args := fun.InArgs()
87+
var defaults int
88+
var variadic bool
89+
known := map[string]struct{}{}
90+
for _, arg := range args {
91+
if arg.HasDefault {
92+
defaults += 1
93+
}
94+
if arg.Mode == ast.FuncParamVariadic {
95+
variadic = true
96+
defaults += 1
97+
}
98+
if arg.Name != "" {
99+
known[arg.Name] = struct{}{}
100+
}
101+
}
102+
103+
argc := len(named) + len(positional)
104+
if variadic {
105+
if argc < (len(args) - defaults) {
106+
continue
107+
}
108+
} else {
109+
if argc > len(args) || argc < (len(args)-defaults) {
110+
continue
111+
}
112+
}
113+
114+
var unknownArgName bool
115+
for _, expr := range named {
116+
if expr.Name != nil {
117+
if _, found := known[*expr.Name]; !found {
118+
unknownArgName = true
119+
}
120+
}
121+
}
122+
if unknownArgName {
123+
continue
124+
}
125+
126+
matches = append(matches, fun)
127+
}
128+
129+
return matches
130+
}
131+
132+
func stableFuncCallArgType(c *catalog.Catalog, call *ast.FuncCall, argIndex int, argName string) *ast.TypeName {
133+
var stable *ast.TypeName
134+
var seen bool
135+
136+
for _, fun := range matchingFuncCallOverloads(c, call) {
137+
args := fun.InArgs()
138+
var current *ast.TypeName
139+
if argName == "" {
140+
if argIndex >= len(args) {
141+
return nil
142+
}
143+
current = args[argIndex].Type
144+
} else {
145+
for _, arg := range args {
146+
if arg.Name == argName {
147+
current = arg.Type
148+
break
149+
}
150+
}
151+
if current == nil {
152+
return nil
153+
}
154+
}
155+
156+
if !seen {
157+
stable = current
158+
seen = true
159+
continue
160+
}
161+
if !sameTypeName(stable, current) {
162+
return nil
163+
}
164+
}
165+
166+
return stable
167+
}
168+
169+
func resolvedFuncCallArgType(fun *catalog.Function, argIndex int, argName string) *ast.TypeName {
170+
if fun == nil {
171+
return nil
172+
}
173+
if argName == "" {
174+
if argIndex < len(fun.Args) {
175+
return fun.Args[argIndex].Type
176+
}
177+
return nil
178+
}
179+
for _, arg := range fun.Args {
180+
if arg.Name == argName {
181+
return arg.Type
182+
}
183+
}
184+
return nil
185+
}
186+
56187
func mergeResolvedParam(existing, incoming Parameter) Parameter {
57188
if existing.Column == nil {
58189
return incoming
@@ -93,8 +224,8 @@ func mergeResolvedParam(existing, incoming Parameter) Parameter {
93224

94225
func (comp *Compiler) incompatibleParamRefError(ref paramRef, existing, incoming Parameter) error {
95226
return &sqlerr.Error{
96-
Code: "42P08",
97-
Message: fmt.Sprintf(
227+
Code: "42P08",
228+
Message: fmt.Sprintf(
98229
"parameter $%d has incompatible types: %s, %s",
99230
ref.ref.Number,
100231
comp.paramTypeString(existing.Column),
@@ -182,6 +313,10 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
182313

183314
var a []Parameter
184315
seen := map[int]int{}
316+
paramCounts := map[int]int{}
317+
for _, ref := range args {
318+
paramCounts[ref.ref.Number] += 1
319+
}
185320

186321
addParam := func(ref paramRef, p Parameter) error {
187322
if idx, ok := seen[p.Number]; ok {
@@ -424,8 +559,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
424559
}
425560

426561
case *ast.FuncCall:
427-
fun, err := c.ResolveFuncCall(n)
428-
if err != nil {
562+
fun, resolveErr := c.ResolveFuncCall(n)
563+
if resolveErr != nil {
429564
// Synthesize a function on the fly to avoid returning with an error
430565
// for an unknown Postgres function (e.g. defined in an extension)
431566
var args []*catalog.Argument
@@ -503,22 +638,20 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
503638
if argName == "" {
504639
if i < len(fun.Args) {
505640
paramName = fun.Args[i].Name
506-
paramType = fun.Args[i].Type
507641
}
508642
} else {
509643
paramName = argName
510-
for _, arg := range fun.Args {
511-
if arg.Name == argName {
512-
paramType = arg.Type
513-
}
514-
}
515-
if paramType == nil {
516-
panic(fmt.Sprintf("named argument %s has no type", paramName))
517-
}
518644
}
519645
if paramName == "" {
520646
paramName = funcName
521647
}
648+
if resolveErr == nil {
649+
if paramCounts[ref.ref.Number] > 1 {
650+
paramType = stableFuncCallArgType(c, n, i, argName)
651+
} else {
652+
paramType = resolvedFuncCallArgType(fun, i, argName)
653+
}
654+
}
522655
if paramType == nil {
523656
paramType = &ast.TypeName{Name: ""}
524657
}

internal/compiler/resolve_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
77
"github.com/sqlc-dev/sqlc/internal/engine/sqlite"
88
"github.com/sqlc-dev/sqlc/internal/sql/ast"
9+
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
910
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
1011
)
1112

@@ -66,3 +67,38 @@ func TestIncompatibleParamRefErrorFormatsTypeNames(t *testing.T) {
6667
t.Fatalf("unexpected message: %q", sqlErr.Message)
6768
}
6869
}
70+
71+
func TestMergeResolvedParamKeepsFirstNameForCompatibleTypes(t *testing.T) {
72+
t.Parallel()
73+
74+
merged := mergeResolvedParam(
75+
Parameter{Number: 1, Column: &Column{Name: "user", DataType: "text"}},
76+
Parameter{Number: 1, Column: &Column{Name: "student_user", DataType: "text"}},
77+
)
78+
79+
if merged.Column == nil {
80+
t.Fatal("expected merged column")
81+
}
82+
if merged.Column.Name != "user" {
83+
t.Fatalf("expected first inferred name to win, got %q", merged.Column.Name)
84+
}
85+
}
86+
87+
func TestResolvedFuncCallArgType(t *testing.T) {
88+
t.Parallel()
89+
90+
fun := &catalog.Function{Args: []*catalog.Argument{
91+
{Name: "lhs", Type: &ast.TypeName{Name: "int8"}},
92+
{Name: "rhs", Type: &ast.TypeName{Name: "text"}},
93+
}}
94+
95+
if got := resolvedFuncCallArgType(fun, 0, ""); got == nil || got.Name != "int8" {
96+
t.Fatalf("expected positional arg type int8, got %#v", got)
97+
}
98+
if got := resolvedFuncCallArgType(fun, 0, "rhs"); got == nil || got.Name != "text" {
99+
t.Fatalf("expected named arg type text, got %#v", got)
100+
}
101+
if got := resolvedFuncCallArgType(fun, 2, ""); got != nil {
102+
t.Fatalf("expected nil for out-of-range positional arg, got %#v", got)
103+
}
104+
}

0 commit comments

Comments
 (0)