123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621 |
- // Copyright (c) 2012-present The upper.io/db authors. All rights reserved.
- //
- // Permission is hereby granted, free of charge, to any person obtaining
- // a copy of this software and associated documentation files (the
- // "Software"), to deal in the Software without restriction, including
- // without limitation the rights to use, copy, modify, merge, publish,
- // distribute, sublicense, and/or sell copies of the Software, and to
- // permit persons to whom the Software is furnished to do so, subject to
- // the following conditions:
- //
- // The above copyright notice and this permission notice shall be
- // included in all copies or substantial portions of the Software.
- //
- // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
- // LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
- // OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
- // WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-
- // Package sqlbuilder provides tools for building custom SQL queries.
- package sqlbuilder
-
- import (
- "context"
- "database/sql"
- "errors"
- "fmt"
- "log"
- "reflect"
- "regexp"
- "sort"
- "strconv"
- "strings"
-
- "upper.io/db.v3"
- "upper.io/db.v3/internal/sqladapter/compat"
- "upper.io/db.v3/internal/sqladapter/exql"
- "upper.io/db.v3/lib/reflectx"
- )
-
// MapOptions represents options for the mapper. They only affect struct
// fields tagged with "omitempty".
type MapOptions struct {
	// IncludeZeroed keeps zero-valued "omitempty" fields instead of skipping
	// them; such fields are mapped to sqlDefault.
	IncludeZeroed bool
	// IncludeNil keeps nil-pointer "omitempty" fields instead of skipping
	// them; such fields are mapped to sqlDefault.
	IncludeNil bool
}

// defaultMapOptions is what Map uses when it receives nil options: both
// zero-valued and nil "omitempty" fields are left out.
var defaultMapOptions = MapOptions{}
-
- type compilable interface {
- Compile() (string, error)
- Arguments() []interface{}
- }
-
- type hasIsZero interface {
- IsZero() bool
- }
-
- type hasArguments interface {
- Arguments() []interface{}
- }
-
- type hasStatement interface {
- statement() *exql.Statement
- }
-
- type iterator struct {
- sess exprDB
- cursor *sql.Rows // This is the main query cursor. It starts as a nil value.
- err error
- }
-
- type fieldValue struct {
- fields []string
- values []interface{}
- }
-
- var (
- reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`)
- reColumnCompareExclude = regexp.MustCompile(`[^a-zA-Z0-9]`)
- )
-
- var (
- sqlPlaceholder = exql.RawValue(`?`)
- )
-
- var (
- errDeprecatedJSONBTag = errors.New(`Tag "jsonb" is deprecated. See "PostgreSQL: jsonb tag" at https://github.com/upper/db/releases/tag/v3.4.0`)
- )
-
- type exprDB interface {
- StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error)
- StatementPrepare(ctx context.Context, stmt *exql.Statement) (*sql.Stmt, error)
- StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error)
- StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error)
-
- Context() context.Context
- }
-
- type sqlBuilder struct {
- sess exprDB
- t *templateWithUtils
- }
-
- // WithSession returns a query builder that is bound to the given database session.
- func WithSession(sess interface{}, t *exql.Template) SQLBuilder {
- if sqlDB, ok := sess.(*sql.DB); ok {
- sess = sqlDB
- }
- return &sqlBuilder{
- sess: sess.(exprDB), // Let it panic, it will show the developer an informative error.
- t: newTemplateWithUtils(t),
- }
- }
-
- // WithTemplate returns a builder that is based on the given template.
- func WithTemplate(t *exql.Template) SQLBuilder {
- return &sqlBuilder{
- t: newTemplateWithUtils(t),
- }
- }
-
- // NewIterator creates an iterator using the given *sql.Rows.
- func NewIterator(rows *sql.Rows) Iterator {
- return &iterator{nil, rows, nil}
- }
-
- func (b *sqlBuilder) Iterator(query interface{}, args ...interface{}) Iterator {
- return b.IteratorContext(b.sess.Context(), query, args...)
- }
-
- func (b *sqlBuilder) IteratorContext(ctx context.Context, query interface{}, args ...interface{}) Iterator {
- rows, err := b.QueryContext(ctx, query, args...)
- return &iterator{b.sess, rows, err}
- }
-
- func (b *sqlBuilder) Prepare(query interface{}) (*sql.Stmt, error) {
- return b.PrepareContext(b.sess.Context(), query)
- }
-
- func (b *sqlBuilder) PrepareContext(ctx context.Context, query interface{}) (*sql.Stmt, error) {
- switch q := query.(type) {
- case *exql.Statement:
- return b.sess.StatementPrepare(ctx, q)
- case string:
- return b.sess.StatementPrepare(ctx, exql.RawSQL(q))
- case db.RawValue:
- return b.PrepareContext(ctx, q.Raw())
- default:
- return nil, fmt.Errorf("unsupported query type %T", query)
- }
- }
-
- func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, error) {
- return b.ExecContext(b.sess.Context(), query, args...)
- }
-
- func (b *sqlBuilder) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) {
- switch q := query.(type) {
- case *exql.Statement:
- return b.sess.StatementExec(ctx, q, args...)
- case string:
- return b.sess.StatementExec(ctx, exql.RawSQL(q), args...)
- case db.RawValue:
- return b.ExecContext(ctx, q.Raw(), q.Arguments()...)
- default:
- return nil, fmt.Errorf("unsupported query type %T", query)
- }
- }
-
- func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, error) {
- return b.QueryContext(b.sess.Context(), query, args...)
- }
-
- func (b *sqlBuilder) QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error) {
- switch q := query.(type) {
- case *exql.Statement:
- return b.sess.StatementQuery(ctx, q, args...)
- case string:
- return b.sess.StatementQuery(ctx, exql.RawSQL(q), args...)
- case db.RawValue:
- return b.QueryContext(ctx, q.Raw(), q.Arguments()...)
- default:
- return nil, fmt.Errorf("unsupported query type %T", query)
- }
- }
-
- func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, error) {
- return b.QueryRowContext(b.sess.Context(), query, args...)
- }
-
- func (b *sqlBuilder) QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error) {
- switch q := query.(type) {
- case *exql.Statement:
- return b.sess.StatementQueryRow(ctx, q, args...)
- case string:
- return b.sess.StatementQueryRow(ctx, exql.RawSQL(q), args...)
- case db.RawValue:
- return b.QueryRowContext(ctx, q.Raw(), q.Arguments()...)
- default:
- return nil, fmt.Errorf("unsupported query type %T", query)
- }
- }
-
- func (b *sqlBuilder) SelectFrom(table ...interface{}) Selector {
- qs := &selector{
- builder: b,
- }
- return qs.From(table...)
- }
-
- func (b *sqlBuilder) Select(columns ...interface{}) Selector {
- qs := &selector{
- builder: b,
- }
- return qs.Columns(columns...)
- }
-
- func (b *sqlBuilder) InsertInto(table string) Inserter {
- qi := &inserter{
- builder: b,
- }
- return qi.Into(table)
- }
-
- func (b *sqlBuilder) DeleteFrom(table string) Deleter {
- qd := &deleter{
- builder: b,
- }
- return qd.setTable(table)
- }
-
- func (b *sqlBuilder) Update(table string) Updater {
- qu := &updater{
- builder: b,
- }
- return qu.setTable(table)
- }
-
- // Map receives a pointer to map or struct and maps it to columns and values.
- func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) {
- var fv fieldValue
- if options == nil {
- options = &defaultMapOptions
- }
-
- itemV := reflect.ValueOf(item)
- if !itemV.IsValid() {
- return nil, nil, nil
- }
-
- itemT := itemV.Type()
-
- if itemT.Kind() == reflect.Ptr {
- // Single dereference. Just in case the user passes a pointer to struct
- // instead of a struct.
- item = itemV.Elem().Interface()
- itemV = reflect.ValueOf(item)
- itemT = itemV.Type()
- }
-
- switch itemT.Kind() {
- case reflect.Struct:
- fieldMap := mapper.TypeMap(itemT).Names
- nfields := len(fieldMap)
-
- fv.values = make([]interface{}, 0, nfields)
- fv.fields = make([]string, 0, nfields)
-
- for _, fi := range fieldMap {
-
- // Check for deprecated JSONB tag
- if _, hasJSONBTag := fi.Options["jsonb"]; hasJSONBTag {
- return nil, nil, errDeprecatedJSONBTag
- }
-
- // Field options
- _, tagOmitEmpty := fi.Options["omitempty"]
-
- fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index)
- if fld.Kind() == reflect.Ptr && fld.IsNil() {
- if tagOmitEmpty && !options.IncludeNil {
- continue
- }
- fv.fields = append(fv.fields, fi.Name)
- if tagOmitEmpty {
- fv.values = append(fv.values, sqlDefault)
- } else {
- fv.values = append(fv.values, nil)
- }
- continue
- }
-
- value := fld.Interface()
-
- isZero := false
- if t, ok := fld.Interface().(hasIsZero); ok {
- if t.IsZero() {
- isZero = true
- }
- } else if fld.Kind() == reflect.Array || fld.Kind() == reflect.Slice {
- if fld.Len() == 0 {
- isZero = true
- }
- } else if reflect.DeepEqual(fi.Zero.Interface(), value) {
- isZero = true
- }
-
- if isZero && tagOmitEmpty && !options.IncludeZeroed {
- continue
- }
-
- fv.fields = append(fv.fields, fi.Name)
- v, err := marshal(value)
- if err != nil {
- return nil, nil, err
- }
- if isZero && tagOmitEmpty {
- v = sqlDefault
- }
- fv.values = append(fv.values, v)
- }
-
- case reflect.Map:
- nfields := itemV.Len()
- fv.values = make([]interface{}, nfields)
- fv.fields = make([]string, nfields)
- mkeys := itemV.MapKeys()
-
- for i, keyV := range mkeys {
- valv := itemV.MapIndex(keyV)
- fv.fields[i] = fmt.Sprintf("%v", keyV.Interface())
-
- v, err := marshal(valv.Interface())
- if err != nil {
- return nil, nil, err
- }
-
- fv.values[i] = v
- }
- default:
- return nil, nil, ErrExpectingPointerToEitherMapOrStruct
- }
-
- sort.Sort(&fv)
-
- return fv.fields, fv.values, nil
- }
-
- func extractArguments(fragments []interface{}) []interface{} {
- args := []interface{}{}
- l := len(fragments)
- for i := 0; i < l; i++ {
- switch v := fragments[i].(type) {
- case hasArguments: // TODO: use this on other places where we want to extract arguments.
- args = append(args, v.Arguments()...)
- }
- }
- return args
- }
-
- func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, error) {
- l := len(columns)
- f := make([]exql.Fragment, l)
- args := []interface{}{}
-
- for i := 0; i < l; i++ {
- switch v := columns[i].(type) {
- case compilable:
- c, err := v.Compile()
- if err != nil {
- return nil, nil, err
- }
- q, a := Preprocess(c, v.Arguments())
- if _, ok := v.(Selector); ok {
- q = "(" + q + ")"
- }
- f[i] = exql.RawValue(q)
- args = append(args, a...)
- case db.Function:
- fnName, fnArgs := v.Name(), v.Arguments()
- if len(fnArgs) == 0 {
- fnName = fnName + "()"
- } else {
- fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")"
- }
- fnName, fnArgs = Preprocess(fnName, fnArgs)
- f[i] = exql.RawValue(fnName)
- args = append(args, fnArgs...)
- case db.RawValue:
- q, a := Preprocess(v.Raw(), v.Arguments())
- f[i] = exql.RawValue(q)
- args = append(args, a...)
- case exql.Fragment:
- f[i] = v
- case string:
- f[i] = exql.ColumnWithName(v)
- case int:
- f[i] = exql.RawValue(fmt.Sprintf("%v", v))
- case interface{}:
- f[i] = exql.ColumnWithName(fmt.Sprintf("%v", v))
- default:
- return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument", v)
- }
- }
- return f, args, nil
- }
-
- func prepareQueryForDisplay(in string) (out string) {
- j := 1
- for i := range in {
- if in[i] == '?' {
- out = out + "$" + strconv.Itoa(j)
- j++
- } else {
- out = out + string(in[i])
- }
- }
-
- out = reInvisibleChars.ReplaceAllString(out, ` `)
- return strings.TrimSpace(out)
- }
-
- func (iter *iterator) NextScan(dst ...interface{}) error {
- if ok := iter.Next(); ok {
- return iter.Scan(dst...)
- }
- if err := iter.Err(); err != nil {
- return err
- }
- return db.ErrNoMoreRows
- }
-
- func (iter *iterator) ScanOne(dst ...interface{}) error {
- defer iter.Close()
- return iter.NextScan(dst...)
- }
-
- func (iter *iterator) Scan(dst ...interface{}) error {
- if err := iter.Err(); err != nil {
- return err
- }
- return iter.cursor.Scan(dst...)
- }
-
- func (iter *iterator) setErr(err error) error {
- iter.err = err
- return iter.err
- }
-
- func (iter *iterator) One(dst interface{}) error {
- if err := iter.Err(); err != nil {
- return err
- }
- defer iter.Close()
- return iter.setErr(iter.next(dst))
- }
-
- func (iter *iterator) All(dst interface{}) error {
- if err := iter.Err(); err != nil {
- return err
- }
- defer iter.Close()
-
- // Fetching all results within the cursor.
- if err := fetchRows(iter, dst); err != nil {
- return iter.setErr(err)
- }
-
- return nil
- }
-
- func (iter *iterator) Err() (err error) {
- return iter.err
- }
-
- func (iter *iterator) Next(dst ...interface{}) bool {
- if err := iter.Err(); err != nil {
- return false
- }
-
- if err := iter.next(dst...); err != nil {
- // ignore db.ErrNoMoreRows, just break.
- if err != db.ErrNoMoreRows {
- iter.setErr(err)
- }
- return false
- }
-
- return true
- }
-
- func (iter *iterator) next(dst ...interface{}) error {
- if iter.cursor == nil {
- return iter.setErr(db.ErrNoMoreRows)
- }
-
- switch len(dst) {
- case 0:
- if ok := iter.cursor.Next(); !ok {
- defer iter.Close()
- err := iter.cursor.Err()
- if err == nil {
- err = db.ErrNoMoreRows
- }
- return err
- }
- return nil
- case 1:
- if err := fetchRow(iter, dst[0]); err != nil {
- defer iter.Close()
- return err
- }
- return nil
- }
-
- return errors.New("Next does not currently supports more than one parameters")
- }
-
- func (iter *iterator) Close() (err error) {
- if iter.cursor != nil {
- err = iter.cursor.Close()
- iter.cursor = nil
- }
- return err
- }
-
- func marshal(v interface{}) (interface{}, error) {
- if m, isMarshaler := v.(db.Marshaler); isMarshaler {
- var err error
- if v, err = m.MarshalDB(); err != nil {
- return nil, err
- }
- }
- return v, nil
- }
-
- func (fv *fieldValue) Len() int {
- return len(fv.fields)
- }
-
- func (fv *fieldValue) Swap(i, j int) {
- fv.fields[i], fv.fields[j] = fv.fields[j], fv.fields[i]
- fv.values[i], fv.values[j] = fv.values[j], fv.values[i]
- }
-
- func (fv *fieldValue) Less(i, j int) bool {
- return fv.fields[i] < fv.fields[j]
- }
-
- type exprProxy struct {
- db *sql.DB
- t *exql.Template
- }
-
- func newSqlgenProxy(db *sql.DB, t *exql.Template) *exprProxy {
- return &exprProxy{db: db, t: t}
- }
-
- func (p *exprProxy) Context() context.Context {
- log.Printf("Missing context")
- return context.Background()
- }
-
- func (p *exprProxy) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) {
- s, err := stmt.Compile(p.t)
- if err != nil {
- return nil, err
- }
- return compat.ExecContext(p.db, ctx, s, args)
- }
-
- func (p *exprProxy) StatementPrepare(ctx context.Context, stmt *exql.Statement) (*sql.Stmt, error) {
- s, err := stmt.Compile(p.t)
- if err != nil {
- return nil, err
- }
- return compat.PrepareContext(p.db, ctx, s)
- }
-
- func (p *exprProxy) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) {
- s, err := stmt.Compile(p.t)
- if err != nil {
- return nil, err
- }
- return compat.QueryContext(p.db, ctx, s, args)
- }
-
- func (p *exprProxy) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) {
- s, err := stmt.Compile(p.t)
- if err != nil {
- return nil, err
- }
- return compat.QueryRowContext(p.db, ctx, s, args), nil
- }
-
- var (
- _ = SQLBuilder(&sqlBuilder{})
- _ = exprDB(&exprProxy{})
- )
-
// joinArguments concatenates the given argument lists, in order, into a
// single slice sized exactly to fit. It returns nil when the lists hold no
// arguments at all.
func joinArguments(args ...[]interface{}) []interface{} {
	var n int
	for _, list := range args {
		n += len(list)
	}
	if n == 0 {
		return nil
	}

	joined := make([]interface{}, 0, n)
	for _, list := range args {
		joined = append(joined, list...)
	}
	return joined
}
|