From c0b47a00dbd26e8dab4cd093c6ca68e26a624573 Mon Sep 17 00:00:00 2001 From: jack-t Date: Thu, 3 Dec 2020 12:55:32 -0500 Subject: [PATCH 1/7] Add support for CTEs --- cte.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ cte_test.go | 42 ++++++++++++++++++++++++++++++++++++++++++ select.go | 26 ++++++++++++++++++++++++++ select_test.go | 22 ++++++++++++++++++++++ 4 files changed, 135 insertions(+) create mode 100644 cte.go create mode 100644 cte_test.go diff --git a/cte.go b/cte.go new file mode 100644 index 00000000..dbb22bef --- /dev/null +++ b/cte.go @@ -0,0 +1,45 @@ +package squirrel + +import ( + "bytes" + "strings" +) + +// CTE represents a single common table expression. They are composed of an alias, a few optional components, and a data manipulation statement, though exactly what sort of statement depends on the database system you're using. MySQL, for example, only allows SELECT statements; others, like PostgreSQL, permit INSERTs, UPDATEs, and DELETEs. +// The optional components supported by this fork of Squirrel include: +// * a list of columns +// * the keyword RECURSIVE, the use of which may place additional constraints on the data manipulation statement +type CTE struct { + Alias string + ColumnList []string + Recursive bool + Expression Sqlizer +} + +// ToSql builds the SQL for a CTE +func (c CTE) ToSql() (string, []interface{}, error) { + + var buf bytes.Buffer + + if c.Recursive { + buf.WriteString("RECURSIVE ") + } + + buf.WriteString(c.Alias) + + if len(c.ColumnList) > 0 { + buf.WriteString("(") + buf.WriteString(strings.Join(c.ColumnList, ", ")) + buf.WriteString(")") + } + + buf.WriteString(" AS (") + sql, args, err := c.Expression.ToSql() + if err != nil { + return "", []interface{}{}, err + } + buf.WriteString(sql) + buf.WriteString(")") + + return buf.String(), args, nil +} diff --git a/cte_test.go b/cte_test.go new file mode 100644 index 00000000..597252f5 --- /dev/null +++ b/cte_test.go @@ -0,0 +1,42 @@ +package squirrel + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNormalCTE(t *testing.T) { + + cte := CTE{ + Alias: "cte", + ColumnList: []string{"abc", "def"}, + Recursive: false, + Expression: Select("abc", "def").From("t").Where(Eq{"abc": 1}), + } + + sql, args, err := cte.ToSql() + + assert.Equal(t, "cte(abc, def) AS (SELECT abc, def FROM t WHERE abc = ?)", sql) + assert.Equal(t, []interface{}{1}, args) + assert.Nil(t, err) + +} + +func TestRecursiveCTE(t *testing.T) { + + // this isn't usually valid SQL, but the point is to test the RECURSIVE part + cte := CTE{ + Alias: "cte", + ColumnList: []string{"abc", "def"}, + Recursive: true, + Expression: Select("abc", "def").From("t").Where(Eq{"abc": 1}), + } + + sql, args, err := cte.ToSql() + + assert.Equal(t, "RECURSIVE cte(abc, def) AS (SELECT abc, def FROM t WHERE abc = ?)", sql) + assert.Equal(t, []interface{}{1}, args) + assert.Nil(t, err) + +} diff --git a/select.go b/select.go index b585344c..6c20943d 100644 --- a/select.go +++ b/select.go @@ -13,6 +13,7 @@ type selectData struct { PlaceholderFormat PlaceholderFormat RunWith BaseRunner Prefixes []Sqlizer + CTEs []Sqlizer Options []string Columns []Sqlizer From Sqlizer @@ -78,6 +79,15 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) { sql.WriteString(" ") } + if len(d.CTEs) > 0 { + sql.WriteString("WITH ") + args, err = appendToSql(d.CTEs, sql, ", ", args) + if err != nil { + return + } + sql.WriteString(" ") + } + sql.WriteString("SELECT ") if len(d.Options) > 0 { @@ -253,6 +263,22 @@ func (b SelectBuilder) Options(options ...string) SelectBuilder { return builder.Extend(b, "Options", options).(SelectBuilder) } +// With adds a non-recursive CTE to the query. +func (b SelectBuilder) With(alias string, expr Sqlizer) SelectBuilder { + return b.WithCTE(CTE{Alias: alias, ColumnList: []string{}, Recursive: false, Expression: expr}) +} + +// WithRecursive adds a recursive CTE to the query. +func (b SelectBuilder) WithRecursive(alias string, expr Sqlizer) SelectBuilder { + return b.WithCTE(CTE{Alias: alias, ColumnList: []string{}, Recursive: true, Expression: expr}) +} + +// WithCTE adds an arbitrary Sqlizer to the query. +// The sqlizer will be sandwiched between the keyword WITH and, if there's more than one CTE, a comma. +func (b SelectBuilder) WithCTE(cte Sqlizer) SelectBuilder { + return builder.Append(b, "CTEs", cte).(SelectBuilder) +} + // Columns adds result columns to the query. func (b SelectBuilder) Columns(columns ...string) SelectBuilder { parts := make([]interface{}, 0, len(columns)) diff --git a/select_test.go b/select_test.go index aa3742fa..24495697 100644 --- a/select_test.go +++ b/select_test.go @@ -277,6 +277,28 @@ func TestSelectSubqueryInConjunctionPlaceholderNumbering(t *testing.T) { expectedSql := "SELECT * WHERE (EXISTS( SELECT a WHERE b = $1 )) AND c = $2" assert.Equal(t, expectedSql, sql) assert.Equal(t, []interface{}{1, 2}, args) +func TestOneCTE(t *testing.T) { + sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).ToSql() + + assert.NoError(t, err) + + assert.Equal(t, "WITH cte AS (SELECT abc FROM def) SELECT * FROM cte", sql) +} + +func TestTwoCTEs(t *testing.T) { + sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).With("cte2", Select("ghi").From("jkl")).ToSql() + + assert.NoError(t, err) + + assert.Equal(t, "WITH cte AS (SELECT abc FROM def), cte2 AS (SELECT ghi FROM jkl) SELECT * FROM cte", sql) +} + +func TestCTEErrorBubblesUp(t *testing.T) { + + // a SELECT with no columns raises an error + _, _, err := Select("*").From("cte").With("cte", SelectBuilder{}.From("def")).ToSql() + + assert.Error(t, err) } func ExampleSelect() { From c2ebf490507f441d1af7ac5f37f0930d7ba7f8ce Mon Sep 17 00:00:00 2001 From: tnissen375 <76869973+tnissen375@users.noreply.github.com> Date: Tue, 22 Feb 2022 21:02:59 +0100 Subject: [PATCH 2/7] Update select.go Add UNION --- select.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/select.go b/select.go index 6c20943d..697ff8ad 100644 --- a/select.go +++ b/select.go @@ -338,6 +338,16 @@ func (b SelectBuilder) CrossJoin(join string, rest ...interface{}) SelectBuilder return b.JoinClause("CROSS JOIN "+join, rest...) } +// Union adds UNION to the query. +func (b SelectBuilder) Union(join string, rest ...interface{}) SelectBuilder { + return b.JoinClause("UNION "+join, rest...) +} + +// UnionAll adds UNION ALL to the query. +func (b SelectBuilder) UnionAll(join string, rest ...interface{}) SelectBuilder { + return b.JoinClause("UNION ALL "+join, rest...) +} + // Where adds an expression to the WHERE clause of the query. // // Expressions are ANDed together in the generated SQL. From f26eba190de5f2fdc22b7180f5ee9307cce746f2 Mon Sep 17 00:00:00 2001 From: tnissen Date: Sat, 26 Feb 2022 13:15:16 +0100 Subject: [PATCH 3/7] UNION support extended --- select.go | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/select.go b/select.go index 697ff8ad..b5e26ba4 100644 --- a/select.go +++ b/select.go @@ -14,6 +14,8 @@ type selectData struct { RunWith BaseRunner Prefixes []Sqlizer CTEs []Sqlizer + Union Sqlizer + UnionAll Sqlizer Options []string Columns []Sqlizer From Sqlizer @@ -110,6 +112,22 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) { } } + if d.Union != nil { + sql.WriteString(" UNION ") + args, err = appendToSql([]Sqlizer{d.Union}, sql, "", args) + if err != nil { + return + } + } + + if d.UnionAll != nil { + sql.WriteString(" UNION ALL ") + args, err = appendToSql([]Sqlizer{d.UnionAll}, sql, "", args) + if err != nil { + return + } + } + if len(d.Joins) > 0 { sql.WriteString(" ") args, err = appendToSql(d.Joins, sql, " ", args) @@ -308,6 +326,20 @@ func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilde return builder.Set(b, "From", Alias(from, alias)).(SelectBuilder) } +// UnionSelect sets a union SelectBuilder which removes duplicate rows +// --> UNION combines the result from multiple SELECT statements into a single result set +func (b SelectBuilder) UnionSelect(union SelectBuilder, alias string) SelectBuilder { + union = union.PlaceholderFormat(Question) + return builder.Set(b, "Union", union).(SelectBuilder) +} + +// UnionAllSelect sets a union SelectBuilder which includes all matching rows +// --> UNION combines the result from multiple SELECT statements into a single result set +func (b SelectBuilder) UnionAllSelect(union SelectBuilder, alias string) SelectBuilder { + union = union.PlaceholderFormat(Question) + return builder.Set(b, "UnionAll", union).(SelectBuilder) +} + // JoinClause adds a join clause to the query. func (b SelectBuilder) JoinClause(pred interface{}, args ...interface{}) SelectBuilder { return builder.Append(b, "Joins", newPart(pred, args...)).(SelectBuilder) @@ -338,12 +370,12 @@ func (b SelectBuilder) CrossJoin(join string, rest ...interface{}) SelectBuilder return b.JoinClause("CROSS JOIN "+join, rest...) } -// Union adds UNION to the query. +// Union adds UNION to the query. (duplicate rows are removed) func (b SelectBuilder) Union(join string, rest ...interface{}) SelectBuilder { return b.JoinClause("UNION "+join, rest...) } -// UnionAll adds UNION ALL to the query. +// UnionAll adds UNION ALL to the query. (includes all matching rows) func (b SelectBuilder) UnionAll(join string, rest ...interface{}) SelectBuilder { return b.JoinClause("UNION ALL "+join, rest...) } From 13fdb29c6ab098de3bda2ff451f21ba697227772 Mon Sep 17 00:00:00 2001 From: tnissen Date: Sat, 26 Feb 2022 14:11:20 +0100 Subject: [PATCH 4/7] Remove not used param --- select.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/select.go b/select.go index b5e26ba4..d7e9a515 100644 --- a/select.go +++ b/select.go @@ -328,14 +328,14 @@ func (b SelectBuilder) FromSelect(from SelectBuilder, alias string) SelectBuilde // UnionSelect sets a union SelectBuilder which removes duplicate rows // --> UNION combines the result from multiple SELECT statements into a single result set -func (b SelectBuilder) UnionSelect(union SelectBuilder, alias string) SelectBuilder { +func (b SelectBuilder) UnionSelect(union SelectBuilder) SelectBuilder { union = union.PlaceholderFormat(Question) return builder.Set(b, "Union", union).(SelectBuilder) } // UnionAllSelect sets a union SelectBuilder which includes all matching rows // --> UNION combines the result from multiple SELECT statements into a single result set -func (b SelectBuilder) UnionAllSelect(union SelectBuilder, alias string) SelectBuilder { +func (b SelectBuilder) UnionAllSelect(union SelectBuilder) SelectBuilder { union = union.PlaceholderFormat(Question) return builder.Set(b, "UnionAll", union).(SelectBuilder) } From d3b045dd460d1f85ca618cf407bd30859c13d0e7 Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Tue, 31 May 2022 14:19:22 -0700 Subject: [PATCH 5/7] fix (unions) Changes to add Union tests and also Get them to pass --- select.go | 24 ++++++++++++------------ select_test.go | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/select.go b/select.go index d7e9a515..db825b1f 100644 --- a/select.go +++ b/select.go @@ -112,33 +112,33 @@ func (d *selectData) toSqlRaw() (sqlStr string, args []interface{}, err error) { } } - if d.Union != nil { - sql.WriteString(" UNION ") - args, err = appendToSql([]Sqlizer{d.Union}, sql, "", args) + if len(d.Joins) > 0 { + sql.WriteString(" ") + args, err = appendToSql(d.Joins, sql, " ", args) if err != nil { return } } - if d.UnionAll != nil { - sql.WriteString(" UNION ALL ") - args, err = appendToSql([]Sqlizer{d.UnionAll}, sql, "", args) + if len(d.WhereParts) > 0 { + sql.WriteString(" WHERE ") + args, err = appendToSql(d.WhereParts, sql, " AND ", args) if err != nil { return } } - if len(d.Joins) > 0 { - sql.WriteString(" ") - args, err = appendToSql(d.Joins, sql, " ", args) + if d.Union != nil { + sql.WriteString(" UNION ") + args, err = appendToSql([]Sqlizer{d.Union}, sql, "", args) if err != nil { return } } - if len(d.WhereParts) > 0 { - sql.WriteString(" WHERE ") - args, err = appendToSql(d.WhereParts, sql, " AND ", args) + if d.UnionAll != nil { + sql.WriteString(" UNION ALL ") + args, err = appendToSql([]Sqlizer{d.UnionAll}, sql, "", args) if err != nil { return } diff --git a/select_test.go b/select_test.go index 22e7bbce..482c7aaf 100644 --- a/select_test.go +++ b/select_test.go @@ -277,6 +277,8 @@ func TestSelectSubqueryInConjunctionPlaceholderNumbering(t *testing.T) { expectedSql := "SELECT * WHERE (EXISTS( SELECT a WHERE b = $1 )) AND c = $2" assert.Equal(t, expectedSql, sql) assert.Equal(t, []interface{}{1, 2}, args) +} + func TestOneCTE(t *testing.T) { sql, _, err := Select("*").From("cte").With("cte", Select("abc").From("def")).ToSql() @@ -473,3 +475,35 @@ func ExampleSelectBuilder_ToSql() { // scan... } } + +func TestSelectBuilderUnionToSql(t *testing.T) { + multi := Select("column1", "column2"). + From("table1"). + Where(Eq{"column1": "test"}). + UnionSelect(Select("column3", "column4").From("table2").Where(Lt{"column4": 5}). + UnionSelect(Select("column5", "column6").From("table3").Where(LtOrEq{"column5": 6}))) + sql, args, err := multi.ToSql() + assert.NoError(t, err) + + expectedSql := `SELECT column1, column2 FROM table1 WHERE column1 = ? ` + + "UNION SELECT column3, column4 FROM table2 WHERE column4 < ? " + + "UNION SELECT column5, column6 FROM table3 WHERE column5 <= ?" + assert.Equal(t, expectedSql, sql) + + expectedArgs := []interface{}{"test", 5, 6} + assert.Equal(t, expectedArgs, args) + + unionAll := Select("count(true) as C"). + From("table1"). + Where(Eq{"column1": []string{"test", "tester"}}). + UnionAllSelect(Select("count(true) as C").From("table2").Where(Select("true").Prefix("NOT EXISTS(").Suffix(")").From("table3").Where(Eq{"id": 5}))) + sql, args, err = unionAll.ToSql() + assert.NoError(t, err) + + expectedSql = `SELECT count(true) as C FROM table1 WHERE column1 IN (?,?) ` + + "UNION ALL SELECT count(true) as C FROM table2 WHERE NOT EXISTS( SELECT true FROM table3 WHERE id = ? )" + assert.Equal(t, expectedSql, sql) + + expectedArgs = []interface{}{"test", "tester", 5} + assert.Equal(t, expectedArgs, args) +} From 72ee965718fb211d3f51bb9f1b49476368fdd67b Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Tue, 31 May 2022 14:24:38 -0700 Subject: [PATCH 6/7] Change query to seem more reasonable --- select_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/select_test.go b/select_test.go index 482c7aaf..7b9fa3f4 100644 --- a/select_test.go +++ b/select_test.go @@ -496,14 +496,14 @@ func TestSelectBuilderUnionToSql(t *testing.T) { unionAll := Select("count(true) as C"). From("table1"). Where(Eq{"column1": []string{"test", "tester"}}). - UnionAllSelect(Select("count(true) as C").From("table2").Where(Select("true").Prefix("NOT EXISTS(").Suffix(")").From("table3").Where(Eq{"id": 5}))) + UnionAllSelect(Select("count(true) as C").From("table2").Where(Select("true").Prefix("NOT EXISTS(").Suffix(")").From("table3").Where("id=table2.column3"))) sql, args, err = unionAll.ToSql() assert.NoError(t, err) expectedSql = `SELECT count(true) as C FROM table1 WHERE column1 IN (?,?) ` + - "UNION ALL SELECT count(true) as C FROM table2 WHERE NOT EXISTS( SELECT true FROM table3 WHERE id = ? )" + "UNION ALL SELECT count(true) as C FROM table2 WHERE NOT EXISTS( SELECT true FROM table3 WHERE id=table2.column3 )" assert.Equal(t, expectedSql, sql) - expectedArgs = []interface{}{"test", "tester", 5} + expectedArgs = []interface{}{"test", "tester"} assert.Equal(t, expectedArgs, args) } From 4ded8f44f9cad15803727ff741560dc65b2f17cb Mon Sep 17 00:00:00 2001 From: Baroukh Ovadia Date: Tue, 31 May 2022 14:28:23 -0700 Subject: [PATCH 7/7] Ensure that placeholder also works --- select_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/select_test.go b/select_test.go index 7b9fa3f4..3e99ac5a 100644 --- a/select_test.go +++ b/select_test.go @@ -493,6 +493,13 @@ func TestSelectBuilderUnionToSql(t *testing.T) { expectedArgs := []interface{}{"test", 5, 6} assert.Equal(t, expectedArgs, args) + sql, _, err = multi.PlaceholderFormat(Dollar).ToSql() + assert.NoError(t, err) + expectedSql = `SELECT column1, column2 FROM table1 WHERE column1 = $1 ` + + "UNION SELECT column3, column4 FROM table2 WHERE column4 < $2 " + + "UNION SELECT column5, column6 FROM table3 WHERE column5 <= $3" + assert.Equal(t, expectedSql, sql) + unionAll := Select("count(true) as C"). From("table1"). Where(Eq{"column1": []string{"test", "tester"}}).