@@ -53,6 +53,137 @@ func compatibleParamTypes(a, b *Column) bool {
5353 a .ArrayDims == b .ArrayDims
5454}
5555
56+ func sameTypeName (a , b * ast.TypeName ) bool {
57+ if a == nil || b == nil {
58+ return a == nil && b == nil
59+ }
60+ return a .Catalog == b .Catalog && a .Schema == b .Schema && a .Name == b .Name
61+ }
62+
63+ func matchingFuncCallOverloads (c * catalog.Catalog , call * ast.FuncCall ) []catalog.Function {
64+ funs , err := c .ListFuncsByName (call .Func )
65+ if err != nil {
66+ return nil
67+ }
68+
69+ var positional []ast.Node
70+ var named []* ast.NamedArgExpr
71+ if call .Args != nil {
72+ for _ , arg := range call .Args .Items {
73+ if narg , ok := arg .(* ast.NamedArgExpr ); ok {
74+ named = append (named , narg )
75+ continue
76+ }
77+ if len (named ) > 0 {
78+ return nil
79+ }
80+ positional = append (positional , arg )
81+ }
82+ }
83+
84+ var matches []catalog.Function
85+ for _ , fun := range funs {
86+ args := fun .InArgs ()
87+ var defaults int
88+ var variadic bool
89+ known := map [string ]struct {}{}
90+ for _ , arg := range args {
91+ if arg .HasDefault {
92+ defaults += 1
93+ }
94+ if arg .Mode == ast .FuncParamVariadic {
95+ variadic = true
96+ defaults += 1
97+ }
98+ if arg .Name != "" {
99+ known [arg .Name ] = struct {}{}
100+ }
101+ }
102+
103+ argc := len (named ) + len (positional )
104+ if variadic {
105+ if argc < (len (args ) - defaults ) {
106+ continue
107+ }
108+ } else {
109+ if argc > len (args ) || argc < (len (args )- defaults ) {
110+ continue
111+ }
112+ }
113+
114+ var unknownArgName bool
115+ for _ , expr := range named {
116+ if expr .Name != nil {
117+ if _ , found := known [* expr .Name ]; ! found {
118+ unknownArgName = true
119+ }
120+ }
121+ }
122+ if unknownArgName {
123+ continue
124+ }
125+
126+ matches = append (matches , fun )
127+ }
128+
129+ return matches
130+ }
131+
132+ func stableFuncCallArgType (c * catalog.Catalog , call * ast.FuncCall , argIndex int , argName string ) * ast.TypeName {
133+ var stable * ast.TypeName
134+ var seen bool
135+
136+ for _ , fun := range matchingFuncCallOverloads (c , call ) {
137+ args := fun .InArgs ()
138+ var current * ast.TypeName
139+ if argName == "" {
140+ if argIndex >= len (args ) {
141+ return nil
142+ }
143+ current = args [argIndex ].Type
144+ } else {
145+ for _ , arg := range args {
146+ if arg .Name == argName {
147+ current = arg .Type
148+ break
149+ }
150+ }
151+ if current == nil {
152+ return nil
153+ }
154+ }
155+
156+ if ! seen {
157+ stable = current
158+ seen = true
159+ continue
160+ }
161+ if ! sameTypeName (stable , current ) {
162+ return nil
163+ }
164+ }
165+
166+ return stable
167+ }
168+
169+ func resolvedFuncCallArgType (fun * catalog.Function , argIndex int , argName string ) * ast.TypeName {
170+ if fun == nil {
171+ return nil
172+ }
173+ if argName == "" {
174+ if argIndex < len (fun .Args ) {
175+ return fun .Args [argIndex ].Type
176+ }
177+ return nil
178+ }
179+ for _ , arg := range fun .Args {
180+ if arg .Name == argName {
181+ return arg .Type
182+ }
183+ }
184+ return nil
185+ }
186+
56187func mergeResolvedParam (existing , incoming Parameter ) Parameter {
57188 if existing .Column == nil {
58189 return incoming
@@ -93,8 +224,8 @@ func mergeResolvedParam(existing, incoming Parameter) Parameter {
93224
94225func (comp * Compiler ) incompatibleParamRefError (ref paramRef , existing , incoming Parameter ) error {
95226 return & sqlerr.Error {
96- Code : "42P08" ,
97- Message : fmt .Sprintf (
227+ Code : "42P08" ,
228+ Message : fmt .Sprintf (
98229 "parameter $%d has incompatible types: %s, %s" ,
99230 ref .ref .Number ,
100231 comp .paramTypeString (existing .Column ),
@@ -182,6 +313,10 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
182313
183314 var a []Parameter
184315 seen := map [int ]int {}
316+ paramCounts := map [int ]int {}
317+ for _ , ref := range args {
318+ paramCounts [ref .ref .Number ] += 1
319+ }
185320
186321 addParam := func (ref paramRef , p Parameter ) error {
187322 if idx , ok := seen [p .Number ]; ok {
@@ -424,8 +559,8 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
424559 }
425560
426561 case * ast.FuncCall :
427- fun , err := c .ResolveFuncCall (n )
428- if err != nil {
562+ fun , resolveErr := c .ResolveFuncCall (n )
563+ if resolveErr != nil {
429564 // Synthesize a function on the fly to avoid returning with an error
430565 // for an unknown Postgres function (e.g. defined in an extension)
431566 var args []* catalog.Argument
@@ -503,22 +638,20 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
503638 if argName == "" {
504639 if i < len (fun .Args ) {
505640 paramName = fun .Args [i ].Name
506- paramType = fun .Args [i ].Type
507641 }
508642 } else {
509643 paramName = argName
510- for _ , arg := range fun .Args {
511- if arg .Name == argName {
512- paramType = arg .Type
513- }
514- }
515- if paramType == nil {
516- panic (fmt .Sprintf ("named argument %s has no type" , paramName ))
517- }
518644 }
519645 if paramName == "" {
520646 paramName = funcName
521647 }
648+ if resolveErr == nil {
649+ if paramCounts [ref .ref .Number ] > 1 {
650+ paramType = stableFuncCallArgType (c , n , i , argName )
651+ } else {
652+ paramType = resolvedFuncCallArgType (fun , i , argName )
653+ }
654+ }
522655 if paramType == nil {
523656 paramType = & ast.TypeName {Name : "" }
524657 }
0 commit comments