Skip to content

Minor query builder improvements #1087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion queries/qm/query_mods.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,13 +194,14 @@ func Distinct(clause string) QueryMod {
}

type withQueryMod struct {
alias string
clause string
args []interface{}
}

// Apply implements QueryMod.Apply.
func (qm withQueryMod) Apply(q *queries.Query) {
queries.AppendWith(q, qm.clause, qm.args...)
queries.AppendWith(q, qm.alias, qm.clause, qm.args...)
}

// With allows you to pass in a Common Table Expression clause (and args)
Expand All @@ -211,6 +212,17 @@ func With(clause string, args ...interface{}) QueryMod {
}
}

// WithSubquery allows you to generate a Common Table Expression using a query
// object to populate the CTE
func WithSubquery(alias string, q *queries.Query) QueryMod {
clause, args := queries.BuildSubquery(q)
return withQueryMod{
alias: alias,
clause: clause,
args: args,
}
}

type selectQueryMod struct {
columns []string
}
Expand Down
12 changes: 9 additions & 3 deletions queries/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type Query struct {

delete bool
update map[string]interface{}
withs []argClause
withs []with
selectCols []string
count bool
from []string
Expand Down Expand Up @@ -88,6 +88,12 @@ type argClause struct {
args []interface{}
}

type with struct {
alias string
clause string
args []interface{}
}

type rawSQL struct {
sql string
args []interface{}
Expand Down Expand Up @@ -398,8 +404,8 @@ func AppendOrderBy(q *Query, clause string, args ...interface{}) {
}

// AppendWith on the query.
func AppendWith(q *Query, clause string, args ...interface{}) {
q.withs = append(q.withs, argClause{clause: clause, args: args})
func AppendWith(q *Query, alias, clause string, args ...interface{}) {
q.withs = append(q.withs, with{alias: alias, clause: clause, args: args})
}

// RemoveSoftDeleteWhere prevents the automatic soft delete where clause
Expand Down
27 changes: 21 additions & 6 deletions queries/query_builders.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,18 @@ var (
// and it's accompanying arguments. Using this method
// allows query building without immediate execution.
func BuildQuery(q *Query) (string, []interface{}) {
return buildQuery(q, true)
}

// BuildSubquery builds a query object into the query string
// and it's accompanying arguments but doesn't append a
// semi-colon allowing the resulting string to be embedded
// as a subquery.
func BuildSubquery(q *Query) (string, []interface{}) {
return buildQuery(q, false)
}

func buildQuery(q *Query, finalize bool) (string, []interface{}) {
var buf *bytes.Buffer
var args []interface{}

Expand All @@ -36,6 +48,10 @@ func BuildQuery(q *Query) (string, []interface{}) {
buf, args = buildSelectQuery(q)
}

if finalize {
buf.WriteByte(';')
}

defer strmangle.PutBuffer(buf)

// Cache the generated query for query object re-use
Expand Down Expand Up @@ -133,7 +149,6 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {

writeModifiers(q, buf, &args)

buf.WriteByte(';')
return buf, args
}

Expand All @@ -155,8 +170,6 @@ func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {

writeModifiers(q, buf, &args)

buf.WriteByte(';')

return buf, args
}

Expand Down Expand Up @@ -199,8 +212,6 @@ func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {

writeModifiers(q, buf, &args)

buf.WriteByte(';')

return buf, args
}

Expand Down Expand Up @@ -614,7 +625,11 @@ func writeCTEs(q *Query, buf *bytes.Buffer, args *[]interface{}) {
withBuf := strmangle.GetBuffer()
lastPos := len(q.withs) - 1
for i, w := range q.withs {
fmt.Fprintf(withBuf, " %s", w.clause)
if w.alias != "" {
fmt.Fprintf(withBuf, " %s AS (%s)", w.alias, w.clause)
} else {
fmt.Fprintf(withBuf, " %s", w.clause)
}
if i >= 0 && i < lastPos {
withBuf.WriteByte(',')
}
Expand Down
37 changes: 31 additions & 6 deletions queries/query_builders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ func TestBuildQuery(t *testing.T) {
{&Query{from: []string{"cats as c", "dogs as d"}, joins: []join{{JoinOuterFull, "dogs d on d.cat_id = cats.id", nil}}}, nil},
{&Query{
from: []string{"t"},
withs: []argClause{
{"cte_0 AS (SELECT * FROM other_t0)", nil},
{"cte_1 AS (SELECT * FROM other_t1 WHERE thing=? AND stuff=?)", []interface{}{3, 7}},
withs: []with{
{"cte_0", "SELECT * FROM other_t0", nil},
{"cte_1", "SELECT * FROM other_t1 WHERE thing=? AND stuff=?", []interface{}{3, 7}},
},
}, []interface{}{3, 7},
},
Expand Down Expand Up @@ -161,6 +161,31 @@ func TestBuildQuery(t *testing.T) {
}
}

func TestBuildSubquery(t *testing.T) {
t.Parallel()

q1 := &Query{}
SetSelect(q1, []string{"foo", "bar"})
SetFrom(q1, "tbl")
q1.dialect = &drivers.Dialect{LQ: '"', RQ: '"', UseIndexPlaceholders: true}

q2 := &Query{}
SetSelect(q2, []string{"foo", "bar"})
SetFrom(q2, "tbl")
q2.dialect = &drivers.Dialect{LQ: '"', RQ: '"', UseIndexPlaceholders: true}

query, _ := BuildQuery(q1)
subquery, _ := BuildSubquery(q2)

if !strings.HasSuffix(query, ";") {
t.Error("BuildQuery() result is missing trailing ';'\n", query)
}

if strings.HasSuffix(subquery, ";") {
t.Error("BuildSubquery() result has trailing ';'\n", subquery)
}
}

func TestWriteStars(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -516,11 +541,11 @@ func TestLimitClause(t *testing.T) {
t.Parallel()

tests := []struct {
limit *int
limit *int
expectPredicate func(sql string) bool
}{
{nil, func(sql string) bool {
return !strings.Contains(sql,"LIMIT")
return !strings.Contains(sql, "LIMIT")
}},
{newIntPtr(0), func(sql string) bool {
return strings.Contains(sql, "LIMIT 0")
Expand All @@ -532,7 +557,7 @@ func TestLimitClause(t *testing.T) {

for i, test := range tests {
q := &Query{
limit: test.limit,
limit: test.limit,
dialect: &drivers.Dialect{LQ: '"', RQ: '"', UseIndexPlaceholders: true, UseTopClause: false},
}
sql, _ := BuildQuery(q)
Expand Down
15 changes: 8 additions & 7 deletions queries/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -604,17 +604,17 @@ func TestAppendWith(t *testing.T) {
t.Parallel()

q := &Query{}
AppendWith(q, "cte_0 AS (SELECT * FROM table_0 WHERE thing=$1 AND stuff=$2)", 5, 10)
AppendWith(q, "cte_1 AS (SELECT * FROM table_1 WHERE thing=$1 AND stuff=$2)", 5, 10)
AppendWith(q, "cte_0", "SELECT * FROM table_0 WHERE thing=$1 AND stuff=$2", 5, 10)
AppendWith(q, "cte_1", "SELECT * FROM table_1 WHERE thing=$1 AND stuff=$2", 5, 10)

if len(q.withs) != 2 {
t.Errorf("Expected len 2, got %d", len(q.withs))
}

if q.withs[0].clause != "cte_0 AS (SELECT * FROM table_0 WHERE thing=$1 AND stuff=$2)" {
if q.withs[0].alias != "cte_0" || q.withs[0].clause != "SELECT * FROM table_0 WHERE thing=$1 AND stuff=$2" {
t.Errorf("Got invalid with on string: %#v", q.withs)
}
if q.withs[1].clause != "cte_1 AS (SELECT * FROM table_1 WHERE thing=$1 AND stuff=$2)" {
if q.withs[1].alias != "cte_1" || q.withs[1].clause != "SELECT * FROM table_1 WHERE thing=$1 AND stuff=$2" {
t.Errorf("Got invalid with on string: %#v", q.withs)
}

Expand All @@ -629,16 +629,17 @@ func TestAppendWith(t *testing.T) {
t.Errorf("Invalid args values, got %#v", q.withs[0].args)
}

q.withs = []argClause{{
clause: "other_cte AS (SELECT * FROM other_table WHERE thing=$1 AND stuff=$2)",
q.withs = []with{{
alias: "other_cte",
clause: "SELECT * FROM other_table WHERE thing=$1 AND stuff=$2",
args: []interface{}{3, 7},
}}

if len(q.withs) != 1 {
t.Errorf("Expected len 1, got %d", len(q.withs))
}

if q.withs[0].clause != "other_cte AS (SELECT * FROM other_table WHERE thing=$1 AND stuff=$2)" {
if q.withs[0].alias != "other_cte" || q.withs[0].clause != "SELECT * FROM other_table WHERE thing=$1 AND stuff=$2" {
t.Errorf("Got invalid with on string: %#v", q.withs)
}
}
Expand Down