Skip to content

Commit 4112a36

Browse files
committed
fix: managed-db issues
1 parent 931dc7f commit 4112a36

File tree

4 files changed

+206
-51
lines changed

4 files changed

+206
-51
lines changed

internal/compiler/find_params.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
144144
p.parent = node
145145

146146
case *ast.SelectStmt:
147+
if n.FromClause != nil && len(n.FromClause.Items) > 0 {
148+
if rv, ok := n.FromClause.Items[0].(*ast.RangeVar); ok {
149+
p.rangeVar = rv
150+
}
151+
}
147152
if n.LimitCount != nil {
148153
p.limitCount = n.LimitCount
149154
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package compiler
2+
3+
import (
4+
"testing"
5+
6+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
7+
)
8+
9+
func TestFindParametersSelectStmtUsesFromRangeVarForWhereParams(t *testing.T) {
10+
t.Parallel()
11+
12+
tableName := "solar_commcard_mapping"
13+
refs, errs := findParameters(&ast.SelectStmt{
14+
FromClause: &ast.List{Items: []ast.Node{&ast.RangeVar{Relname: &tableName}}},
15+
WhereClause: &ast.A_Expr{
16+
Lexpr: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "deviceId"}}}},
17+
Rexpr: &ast.ParamRef{Number: 1, Location: 1},
18+
},
19+
})
20+
if len(errs) > 0 {
21+
t.Fatalf("findParameters returned errors: %v", errs)
22+
}
23+
if len(refs) != 1 {
24+
t.Fatalf("expected 1 ref, got %d", len(refs))
25+
}
26+
if refs[0].rv == nil || refs[0].rv.Relname == nil {
27+
t.Fatal("expected ref to carry range var")
28+
}
29+
if got := *refs[0].rv.Relname; got != tableName {
30+
t.Fatalf("expected ref range var %q, got %q", tableName, got)
31+
}
32+
}

internal/compiler/resolve.go

Lines changed: 64 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,58 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
344344
})
345345
}
346346

347+
addColumnParam := func(ref paramRef, key string, location int) error {
348+
var schema, rel string
349+
// TODO: Deprecate defaultTable
350+
if defaultTable != nil {
351+
schema = defaultTable.Schema
352+
rel = defaultTable.Name
353+
}
354+
if ref.rv != nil {
355+
fqn, err := ParseTableName(ref.rv)
356+
if err != nil {
357+
return err
358+
}
359+
schema = fqn.Schema
360+
rel = fqn.Name
361+
}
362+
if schema == "" {
363+
schema = c.DefaultSchema
364+
}
365+
366+
tableMap, ok := typeMap[schema][rel]
367+
if !ok {
368+
return sqlerr.RelationNotFound(rel)
369+
}
370+
371+
if c, ok := tableMap[key]; ok {
372+
defaultP := named.NewInferredParam(key, c.IsNotNull)
373+
p, isNamed := params.FetchMerge(ref.ref.Number, defaultP)
374+
return addParam(ref, Parameter{
375+
Number: ref.ref.Number,
376+
Column: &Column{
377+
Name: p.Name(),
378+
OriginalName: c.Name,
379+
DataType: dataType(&c.Type),
380+
NotNull: p.NotNull(),
381+
Unsigned: c.IsUnsigned,
382+
IsArray: c.IsArray,
383+
ArrayDims: c.ArrayDims,
384+
Table: &ast.TableName{Schema: schema, Name: rel},
385+
Length: c.Length,
386+
IsNamedParam: isNamed,
387+
IsSqlcSlice: p.IsSqlcSlice(),
388+
},
389+
})
390+
}
391+
392+
return &sqlerr.Error{
393+
Code: "42703",
394+
Message: fmt.Sprintf("column %q does not exist", key),
395+
Location: location,
396+
}
397+
}
398+
347399
for _, ref := range args {
348400
switch n := ref.parent.(type) {
349401

@@ -434,7 +486,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
434486
}
435487

436488
search := tables
437-
if alias != "" {
489+
if alias == "" && ref.rv != nil {
490+
fqn, err := ParseTableName(ref.rv)
491+
if err != nil {
492+
return nil, err
493+
}
494+
search = []*ast.TableName{fqn}
495+
} else if alias != "" {
438496
if original, ok := aliasMap[alias]; ok {
439497
search = []*ast.TableName{original}
440498
} else {
@@ -704,58 +762,13 @@ func (comp *Compiler) resolveCatalogRefs(qc *QueryCatalog, rvs []*ast.RangeVar,
704762
if n.Name == nil {
705763
return nil, fmt.Errorf("*ast.ResTarget has nil name")
706764
}
707-
key := *n.Name
708-
709-
var schema, rel string
710-
// TODO: Deprecate defaultTable
711-
if defaultTable != nil {
712-
schema = defaultTable.Schema
713-
rel = defaultTable.Name
714-
}
715-
if ref.rv != nil {
716-
fqn, err := ParseTableName(ref.rv)
717-
if err != nil {
718-
return nil, err
719-
}
720-
schema = fqn.Schema
721-
rel = fqn.Name
722-
}
723-
if schema == "" {
724-
schema = c.DefaultSchema
725-
}
726-
727-
tableMap, ok := typeMap[schema][rel]
728-
if !ok {
729-
return nil, sqlerr.RelationNotFound(rel)
765+
if err := addColumnParam(ref, *n.Name, n.Location); err != nil {
766+
return nil, err
730767
}
731768

732-
if c, ok := tableMap[key]; ok {
733-
defaultP := named.NewInferredParam(key, c.IsNotNull)
734-
p, isNamed := params.FetchMerge(ref.ref.Number, defaultP)
735-
if err := addParam(ref, Parameter{
736-
Number: ref.ref.Number,
737-
Column: &Column{
738-
Name: p.Name(),
739-
OriginalName: c.Name,
740-
DataType: dataType(&c.Type),
741-
NotNull: p.NotNull(),
742-
Unsigned: c.IsUnsigned,
743-
IsArray: c.IsArray,
744-
ArrayDims: c.ArrayDims,
745-
Table: &ast.TableName{Schema: schema, Name: rel},
746-
Length: c.Length,
747-
IsNamedParam: isNamed,
748-
IsSqlcSlice: p.IsSqlcSlice(),
749-
},
750-
}); err != nil {
751-
return nil, err
752-
}
753-
} else {
754-
return nil, &sqlerr.Error{
755-
Code: "42703",
756-
Message: fmt.Sprintf("column %q does not exist", key),
757-
Location: n.Location,
758-
}
769+
case *ast.String:
770+
if err := addColumnParam(ref, n.Str, n.Pos()); err != nil {
771+
return nil, err
759772
}
760773

761774
case *ast.TypeCast:

internal/compiler/resolve_test.go

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"github.com/sqlc-dev/sqlc/internal/engine/sqlite"
88
"github.com/sqlc-dev/sqlc/internal/sql/ast"
99
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
10+
"github.com/sqlc-dev/sqlc/internal/sql/named"
1011
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
1112
)
1213

@@ -102,3 +103,107 @@ func TestResolvedFuncCallArgType(t *testing.T) {
102103
t.Fatalf("expected nil for out-of-range positional arg, got %#v", got)
103104
}
104105
}
106+
107+
func TestResolveCatalogRefsInsertTargetStringInfersColumnName(t *testing.T) {
108+
t.Parallel()
109+
110+
comp := &Compiler{parser: postgresql.NewParser(), catalog: postgresql.NewCatalog()}
111+
112+
var schema *catalog.Schema
113+
for _, s := range comp.catalog.Schemas {
114+
if s.Name == comp.catalog.DefaultSchema {
115+
schema = s
116+
break
117+
}
118+
}
119+
if schema == nil {
120+
t.Fatal("default schema not found")
121+
}
122+
123+
tableName := "solar_commcard_mapping"
124+
schema.Tables = append(schema.Tables, &catalog.Table{
125+
Rel: &ast.TableName{Schema: schema.Name, Name: tableName},
126+
Columns: []*catalog.Column{&catalog.Column{
127+
Name: "deviceId",
128+
Type: ast.TypeName{Schema: "pg_catalog", Name: "int8"},
129+
IsNotNull: true,
130+
}},
131+
})
132+
133+
rv := &ast.RangeVar{Relname: &tableName}
134+
params, err := comp.resolveCatalogRefs(nil, []*ast.RangeVar{rv}, []paramRef{{
135+
parent: &ast.String{Str: "deviceId"},
136+
rv: rv,
137+
ref: &ast.ParamRef{Number: 1},
138+
}}, named.NewParamSet(nil, true), nil)
139+
if err != nil {
140+
t.Fatalf("resolveCatalogRefs returned error: %v", err)
141+
}
142+
if len(params) != 1 {
143+
t.Fatalf("expected 1 param, got %d", len(params))
144+
}
145+
if params[0].Column == nil {
146+
t.Fatal("expected resolved column metadata")
147+
}
148+
if params[0].Column.Name != "deviceId" {
149+
t.Fatalf("expected inferred name deviceId, got %q", params[0].Column.Name)
150+
}
151+
if params[0].Column.OriginalName != "deviceId" {
152+
t.Fatalf("expected original name deviceId, got %q", params[0].Column.OriginalName)
153+
}
154+
if params[0].Column.DataType != "pg_catalog.int8" {
155+
t.Fatalf("expected data type pg_catalog.int8, got %q", params[0].Column.DataType)
156+
}
157+
if params[0].Column.Table == nil || params[0].Column.Table.Name != tableName {
158+
t.Fatalf("expected table %q, got %#v", tableName, params[0].Column.Table)
159+
}
160+
}
161+
162+
func TestResolveCatalogRefsAExprUsesScopedRangeVar(t *testing.T) {
163+
t.Parallel()
164+
165+
comp := &Compiler{parser: postgresql.NewParser(), catalog: postgresql.NewCatalog()}
166+
167+
var schema *catalog.Schema
168+
for _, s := range comp.catalog.Schemas {
169+
if s.Name == comp.catalog.DefaultSchema {
170+
schema = s
171+
break
172+
}
173+
}
174+
if schema == nil {
175+
t.Fatal("default schema not found")
176+
}
177+
178+
tableName := "solar_commcard_mapping"
179+
schema.Tables = append(schema.Tables, &catalog.Table{
180+
Rel: &ast.TableName{Schema: schema.Name, Name: tableName},
181+
Columns: []*catalog.Column{{
182+
Name: "deviceId",
183+
Type: ast.TypeName{Schema: "pg_catalog", Name: "int8"},
184+
IsNotNull: true,
185+
}},
186+
})
187+
188+
rv := &ast.RangeVar{Relname: &tableName}
189+
params, err := comp.resolveCatalogRefs(nil, []*ast.RangeVar{rv, rv}, []paramRef{{
190+
parent: &ast.A_Expr{
191+
Lexpr: &ast.ColumnRef{Fields: &ast.List{Items: []ast.Node{&ast.String{Str: "deviceId"}}}},
192+
Rexpr: &ast.ParamRef{Number: 1},
193+
},
194+
rv: rv,
195+
ref: &ast.ParamRef{Number: 1},
196+
}}, named.NewParamSet(nil, true), nil)
197+
if err != nil {
198+
t.Fatalf("resolveCatalogRefs returned error: %v", err)
199+
}
200+
if len(params) != 1 {
201+
t.Fatalf("expected 1 param, got %d", len(params))
202+
}
203+
if params[0].Column == nil {
204+
t.Fatal("expected resolved column metadata")
205+
}
206+
if params[0].Column.Name != "deviceId" {
207+
t.Fatalf("expected inferred name deviceId, got %q", params[0].Column.Name)
208+
}
209+
}

0 commit comments

Comments
 (0)