123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526 |
- package sqlbuilder
-
- import (
- "context"
- "database/sql"
- "errors"
- "fmt"
- "strings"
-
- "upper.io/db.v3"
- "upper.io/db.v3/internal/immutable"
- "upper.io/db.v3/internal/sqladapter/exql"
- )
-
- type selectorQuery struct {
- table *exql.Columns
- tableArgs []interface{}
-
- as string
-
- distinct bool
-
- where *exql.Where
- whereArgs []interface{}
-
- groupBy *exql.GroupBy
- groupByArgs []interface{}
-
- orderBy *exql.OrderBy
- orderByArgs []interface{}
-
- limit exql.Limit
- offset exql.Offset
-
- columns *exql.Columns
- columnsArgs []interface{}
-
- joins []*exql.Join
- joinsArgs []interface{}
-
- amendFn func(string) string
- }
-
- func (sq *selectorQuery) and(b *sqlBuilder, terms ...interface{}) error {
- where, whereArgs := b.t.toWhereWithArguments(terms)
-
- if sq.where == nil {
- sq.where, sq.whereArgs = &exql.Where{}, []interface{}{}
- }
- sq.where.Append(&where)
- sq.whereArgs = append(sq.whereArgs, whereArgs...)
-
- return nil
- }
-
- func (sq *selectorQuery) arguments() []interface{} {
- return joinArguments(
- sq.columnsArgs,
- sq.tableArgs,
- sq.joinsArgs,
- sq.whereArgs,
- sq.groupByArgs,
- sq.orderByArgs,
- )
- }
-
- func (sq *selectorQuery) statement() *exql.Statement {
- stmt := &exql.Statement{
- Type: exql.Select,
- Table: sq.table,
- Columns: sq.columns,
- Distinct: sq.distinct,
- Limit: sq.limit,
- Offset: sq.offset,
- Where: sq.where,
- OrderBy: sq.orderBy,
- GroupBy: sq.groupBy,
- }
-
- if len(sq.joins) > 0 {
- stmt.Joins = exql.JoinConditions(sq.joins...)
- }
-
- stmt.SetAmendment(sq.amendFn)
-
- return stmt
- }
-
- func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error {
- fragments, args, err := columnFragments(tables)
- if err != nil {
- return err
- }
-
- if sq.joins == nil {
- sq.joins = []*exql.Join{}
- }
- sq.joins = append(sq.joins,
- &exql.Join{
- Type: t,
- Table: exql.JoinColumns(fragments...),
- },
- )
-
- sq.joinsArgs = append(sq.joinsArgs, args...)
-
- return nil
- }
-
- type selector struct {
- builder *sqlBuilder
-
- fn func(*selectorQuery) error
- prev *selector
- }
-
- var _ = immutable.Immutable(&selector{})
-
- func (sel *selector) SQLBuilder() *sqlBuilder {
- if sel.prev == nil {
- return sel.builder
- }
- return sel.prev.SQLBuilder()
- }
-
- func (sel *selector) String() string {
- s, err := sel.Compile()
- if err != nil {
- panic(err.Error())
- }
- return prepareQueryForDisplay(s)
- }
-
- func (sel *selector) frame(fn func(*selectorQuery) error) *selector {
- return &selector{prev: sel, fn: fn}
- }
-
- func (sel *selector) clone() Selector {
- return sel.frame(func(*selectorQuery) error {
- return nil
- })
- }
-
- func (sel *selector) From(tables ...interface{}) Selector {
- return sel.frame(
- func(sq *selectorQuery) error {
- fragments, args, err := columnFragments(tables)
- if err != nil {
- return err
- }
- sq.table = exql.JoinColumns(fragments...)
- sq.tableArgs = args
- return nil
- },
- )
- }
-
- func (sel *selector) setColumns(columns ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- sq.columns = nil
- return sq.pushColumns(columns...)
- })
- }
-
- func (sel *selector) Columns(columns ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- return sq.pushColumns(columns...)
- })
- }
-
- func (sq *selectorQuery) pushColumns(columns ...interface{}) error {
- f, args, err := columnFragments(columns)
- if err != nil {
- return err
- }
-
- c := exql.JoinColumns(f...)
-
- if sq.columns != nil {
- sq.columns.Append(c)
- } else {
- sq.columns = c
- }
-
- sq.columnsArgs = append(sq.columnsArgs, args...)
- return nil
- }
-
- func (sel *selector) Distinct(exps ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- sq.distinct = true
- return sq.pushColumns(exps...)
- })
- }
-
- func (sel *selector) Where(terms ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- if len(terms) == 1 && terms[0] == nil {
- sq.where, sq.whereArgs = &exql.Where{}, []interface{}{}
- return nil
- }
- return sq.and(sel.SQLBuilder(), terms...)
- })
- }
-
- func (sel *selector) And(terms ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- return sq.and(sel.SQLBuilder(), terms...)
- })
- }
-
- func (sel *selector) Amend(fn func(string) string) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- sq.amendFn = fn
- return nil
- })
- }
-
- func (sel *selector) Arguments() []interface{} {
- sq, err := sel.build()
- if err != nil {
- return nil
- }
- return sq.arguments()
- }
-
- func (sel *selector) GroupBy(columns ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- fragments, args, err := columnFragments(columns)
- if err != nil {
- return err
- }
-
- if fragments != nil {
- sq.groupBy = exql.GroupByColumns(fragments...)
- }
- sq.groupByArgs = args
-
- return nil
- })
- }
-
- func (sel *selector) OrderBy(columns ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
-
- if len(columns) == 1 && columns[0] == nil {
- sq.orderBy = nil
- sq.orderByArgs = nil
- return nil
- }
-
- var sortColumns exql.SortColumns
-
- for i := range columns {
- var sort *exql.SortColumn
-
- switch value := columns[i].(type) {
- case db.RawValue:
- query, args := Preprocess(value.Raw(), value.Arguments())
- sort = &exql.SortColumn{
- Column: exql.RawValue(query),
- }
- sq.orderByArgs = append(sq.orderByArgs, args...)
- case db.Function:
- fnName, fnArgs := value.Name(), value.Arguments()
- if len(fnArgs) == 0 {
- fnName = fnName + "()"
- } else {
- fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
- }
- fnName, fnArgs = Preprocess(fnName, fnArgs)
- sort = &exql.SortColumn{
- Column: exql.RawValue(fnName),
- }
- sq.orderByArgs = append(sq.orderByArgs, fnArgs...)
- case string:
- if strings.HasPrefix(value, "-") {
- sort = &exql.SortColumn{
- Column: exql.ColumnWithName(value[1:]),
- Order: exql.Descendent,
- }
- } else {
- chunks := strings.SplitN(value, " ", 2)
-
- order := exql.Ascendent
- if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" {
- order = exql.Descendent
- }
-
- sort = &exql.SortColumn{
- Column: exql.ColumnWithName(chunks[0]),
- Order: order,
- }
- }
- default:
- return fmt.Errorf("Can't sort by type %T", value)
- }
- sortColumns.Columns = append(sortColumns.Columns, sort)
- }
-
- sq.orderBy = &exql.OrderBy{
- SortColumns: &sortColumns,
- }
- return nil
- })
- }
-
- func (sel *selector) Using(columns ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
-
- joins := len(sq.joins)
- if joins == 0 {
- return errors.New(`cannot use Using() without a preceding Join() expression`)
- }
-
- lastJoin := sq.joins[joins-1]
- if lastJoin.On != nil {
- return errors.New(`cannot use Using() and On() with the same Join() expression`)
- }
-
- fragments, args, err := columnFragments(columns)
- if err != nil {
- return err
- }
-
- sq.joinsArgs = append(sq.joinsArgs, args...)
- lastJoin.Using = exql.UsingColumns(fragments...)
-
- return nil
- })
- }
-
- func (sel *selector) FullJoin(tables ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- return sq.pushJoin("FULL", tables)
- })
- }
-
- func (sel *selector) CrossJoin(tables ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- return sq.pushJoin("CROSS", tables)
- })
- }
-
- func (sel *selector) RightJoin(tables ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- return sq.pushJoin("RIGHT", tables)
- })
- }
-
- func (sel *selector) LeftJoin(tables ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- return sq.pushJoin("LEFT", tables)
- })
- }
-
- func (sel *selector) Join(tables ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- return sq.pushJoin("", tables)
- })
- }
-
- func (sel *selector) On(terms ...interface{}) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- joins := len(sq.joins)
-
- if joins == 0 {
- return errors.New(`cannot use On() without a preceding Join() expression`)
- }
-
- lastJoin := sq.joins[joins-1]
- if lastJoin.On != nil {
- return errors.New(`cannot use Using() and On() with the same Join() expression`)
- }
-
- w, a := sel.SQLBuilder().t.toWhereWithArguments(terms)
- o := exql.On(w)
-
- lastJoin.On = &o
-
- sq.joinsArgs = append(sq.joinsArgs, a...)
-
- return nil
- })
- }
-
- func (sel *selector) Limit(n int) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- if n < 0 {
- n = 0
- }
- sq.limit = exql.Limit(n)
- return nil
- })
- }
-
- func (sel *selector) Offset(n int) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- if n < 0 {
- n = 0
- }
- sq.offset = exql.Offset(n)
- return nil
- })
- }
-
- func (sel *selector) template() *exql.Template {
- return sel.SQLBuilder().t.Template
- }
-
- func (sel *selector) As(alias string) Selector {
- return sel.frame(func(sq *selectorQuery) error {
- if sq.table == nil {
- return errors.New("Cannot use As() without a preceding From() expression")
- }
- last := len(sq.table.Columns) - 1
- if raw, ok := sq.table.Columns[last].(*exql.Raw); ok {
- compiled, err := exql.ColumnWithName(alias).Compile(sel.template())
- if err != nil {
- return err
- }
- sq.table.Columns[last] = exql.RawValue(raw.Value + " AS " + compiled)
- }
- return nil
- })
- }
-
- func (sel *selector) statement() *exql.Statement {
- sq, _ := sel.build()
- return sq.statement()
- }
-
- func (sel *selector) QueryRow() (*sql.Row, error) {
- return sel.QueryRowContext(sel.SQLBuilder().sess.Context())
- }
-
- func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) {
- sq, err := sel.build()
- if err != nil {
- return nil, err
- }
-
- return sel.SQLBuilder().sess.StatementQueryRow(ctx, sq.statement(), sq.arguments()...)
- }
-
- func (sel *selector) Prepare() (*sql.Stmt, error) {
- return sel.PrepareContext(sel.SQLBuilder().sess.Context())
- }
-
- func (sel *selector) PrepareContext(ctx context.Context) (*sql.Stmt, error) {
- sq, err := sel.build()
- if err != nil {
- return nil, err
- }
- return sel.SQLBuilder().sess.StatementPrepare(ctx, sq.statement())
- }
-
- func (sel *selector) Query() (*sql.Rows, error) {
- return sel.QueryContext(sel.SQLBuilder().sess.Context())
- }
-
- func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) {
- sq, err := sel.build()
- if err != nil {
- return nil, err
- }
- return sel.SQLBuilder().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...)
- }
-
- func (sel *selector) Iterator() Iterator {
- return sel.IteratorContext(sel.SQLBuilder().sess.Context())
- }
-
- func (sel *selector) IteratorContext(ctx context.Context) Iterator {
- sess := sel.SQLBuilder().sess
- sq, err := sel.build()
- if err != nil {
- return &iterator{sess, nil, err}
- }
-
- rows, err := sess.StatementQuery(ctx, sq.statement(), sq.arguments()...)
- return &iterator{sess, rows, err}
- }
-
- func (sel *selector) Paginate(pageSize uint) Paginator {
- return newPaginator(sel.clone(), pageSize)
- }
-
- func (sel *selector) All(destSlice interface{}) error {
- return sel.Iterator().All(destSlice)
- }
-
- func (sel *selector) One(dest interface{}) error {
- return sel.Iterator().One(dest)
- }
-
- func (sel *selector) build() (*selectorQuery, error) {
- sq, err := immutable.FastForward(sel)
- if err != nil {
- return nil, err
- }
- return sq.(*selectorQuery), nil
- }
-
- func (sel *selector) Compile() (string, error) {
- return sel.statement().Compile(sel.template())
- }
-
- func (sel *selector) Prev() immutable.Immutable {
- if sel == nil {
- return nil
- }
- return sel.prev
- }
-
- func (sel *selector) Fn(in interface{}) error {
- if sel.fn == nil {
- return nil
- }
- return sel.fn(in.(*selectorQuery))
- }
-
- func (sel *selector) Base() interface{} {
- return &selectorQuery{}
- }
|