Skip to content

Commit 26f6650

Browse files
committed
Analyse OUT arguments in PostgreSQL procedures
1 parent 21e6557 commit 26f6650

File tree

2 files changed

+97
-22
lines changed

2 files changed

+97
-22
lines changed

internal/compiler/output_columns.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er
5959

6060
targets := &ast.List{}
6161
switch n := node.(type) {
62+
case *ast.CallStmt:
63+
fun, err := qc.catalog.ResolveFuncCall(n.FuncCall)
64+
if err != nil {
65+
return nil, err
66+
}
67+
var cols []*Column
68+
for _, arg := range fun.Args {
69+
if arg.Mode == ast.FuncParamOut || arg.Mode == ast.FuncParamInOut || arg.Mode == ast.FuncParamTable {
70+
name := arg.Name
71+
typeName := arg.Type.Name
72+
if arg.Type.Names != nil && len(arg.Type.Names.Items) > 0 {
73+
typeName = astutils.Join(arg.Type.Names, ".")
74+
}
75+
cols = append(cols, &Column{Name: name, DataType: typeName, NotNull: false})
76+
}
77+
}
78+
return cols, nil
6279
case *ast.DeleteStmt:
6380
targets = n.ReturningList
6481
case *ast.InsertStmt:

internal/sql/catalog/public.go

Lines changed: 80 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -64,49 +64,107 @@ func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) {
6464
}
6565

6666
for _, fun := range funs {
67-
args := fun.InArgs()
68-
var defaults int
69-
var variadic bool
67+
// Separate input and output args from the function signature
68+
inArgs := fun.InArgs()
69+
outArgs := fun.OutArgs()
70+
71+
// Build known argument names from all parameters (IN/OUT/INOUT/etc.)
7072
known := map[string]struct{}{}
71-
for _, arg := range args {
73+
for _, a := range fun.Args {
74+
if a.Name != "" {
75+
known[a.Name] = struct{}{}
76+
}
77+
}
78+
79+
// Count defaults and whether the last IN arg is variadic
80+
var defaultsIn int
81+
var variadic bool
82+
for _, arg := range inArgs {
7283
if arg.HasDefault {
73-
defaults += 1
84+
defaultsIn += 1
7485
}
7586
if arg.Mode == ast.FuncParamVariadic {
7687
variadic = true
77-
defaults += 1
88+
// Treat the variadic parameter like having a default for count checks
89+
defaultsIn += 1
7890
}
79-
if arg.Name != "" {
80-
known[arg.Name] = struct{}{}
91+
}
92+
93+
// Tally named arguments provided by the call: which refer to IN vs OUT names
94+
var namedIn, namedOut int
95+
var unknownArgName bool
96+
for _, expr := range named {
97+
if expr.Name != nil {
98+
name := *expr.Name
99+
if _, ok := known[name]; !ok {
100+
unknownArgName = true
101+
continue
102+
}
103+
// Classify whether the provided named arg matches an IN or OUT param
104+
var isIn bool
105+
for _, a := range inArgs {
106+
if a.Name == name {
107+
isIn = true
108+
break
109+
}
110+
}
111+
if isIn {
112+
namedIn += 1
113+
} else {
114+
// If not IN, treat it as an OUT placeholder/name
115+
namedOut += 1
116+
}
81117
}
82118
}
119+
if unknownArgName {
120+
// Provided a named argument that doesn't exist in the signature
121+
continue
122+
}
83123

124+
// Positional arguments always come first (we validated above that
125+
// positional cannot follow named). They fill IN parameters first; any
126+
// excess positional arguments are treated as placeholders for OUT params
127+
var posFillIn = len(positional)
128+
if posFillIn > len(inArgs) {
129+
posFillIn = len(inArgs)
130+
}
131+
// Count how many IN arguments are provided (positional for IN + named for IN)
132+
inProvided := posFillIn + namedIn
133+
134+
// Validate IN argument counts against the signature considering defaults/variadic
84135
if variadic {
85-
if (len(named) + len(positional)) < (len(args) - defaults) {
136+
if inProvided < (len(inArgs) - defaultsIn) {
86137
continue
87138
}
88139
} else {
89-
if (len(named) + len(positional)) > len(args) {
140+
if inProvided > len(inArgs) {
90141
continue
91142
}
92-
if (len(named) + len(positional)) < (len(args) - defaults) {
143+
if inProvided < (len(inArgs) - defaultsIn) {
93144
continue
94145
}
95146
}
96147

97-
// Validate that the provided named arguments exist in the function
98-
var unknownArgName bool
99-
for _, expr := range named {
100-
if expr.Name != nil {
101-
if _, found := known[*expr.Name]; !found {
102-
unknownArgName = true
103-
}
104-
}
148+
// Validate OUT placeholders. These are only valid in procedure calls.
149+
// For normal function invocation, callers cannot pass values for OUT params.
150+
posOut := 0
151+
if len(positional) > len(inArgs) {
152+
posOut = len(positional) - len(inArgs)
105153
}
106-
if unknownArgName {
107-
continue
154+
outProvided := posOut + namedOut
155+
if fun.ReturnType == nil {
156+
// Procedure: allow passing placeholders for OUT params, but not more than available
157+
if outProvided > len(outArgs) {
158+
continue
159+
}
160+
} else {
161+
// Function: do not allow any OUT placeholders
162+
if outProvided > 0 {
163+
continue
164+
}
108165
}
109166

167+
// All checks passed for this candidate
110168
return &fun, nil
111169
}
112170

@@ -117,7 +175,7 @@ func (c *Catalog) ResolveFuncCall(call *ast.FuncCall) (*Function, error) {
117175

118176
return nil, &sqlerr.Error{
119177
Code: "42883",
120-
Message: fmt.Sprintf("function %s(%s) does not exist", call.Func.Name, strings.Join(sig, ", ")),
178+
Message: fmt.Sprintf("CODE 42883: function %s(%s) does not exist", call.Func.Name, strings.Join(sig, ", ")),
121179
Location: call.Pos(),
122180
// Hint: "No function matches the given name and argument types. You might need to add explicit type casts.",
123181
}

0 commit comments

Comments
 (0)