123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- package sqlbuilder
-
- import (
- "context"
- "database/sql"
-
- "upper.io/db.v3/internal/immutable"
- "upper.io/db.v3/internal/sqladapter/exql"
- )
-
- type updaterQuery struct {
- table string
-
- columnValues *exql.ColumnValues
- columnValuesArgs []interface{}
-
- limit int
-
- where *exql.Where
- whereArgs []interface{}
-
- err error
-
- amendFn func(string) string
- }
-
- func (uq *updaterQuery) and(b *sqlBuilder, terms ...interface{}) error {
- where, whereArgs := b.t.toWhereWithArguments(terms)
-
- if uq.where == nil {
- uq.where, uq.whereArgs = &exql.Where{}, []interface{}{}
- }
- uq.where.Append(&where)
- uq.whereArgs = append(uq.whereArgs, whereArgs...)
-
- return nil
- }
-
- func (uq *updaterQuery) statement() *exql.Statement {
- stmt := &exql.Statement{
- Type: exql.Update,
- Table: exql.TableWithName(uq.table),
- ColumnValues: uq.columnValues,
- }
-
- if uq.where != nil {
- stmt.Where = uq.where
- }
-
- if uq.limit != 0 {
- stmt.Limit = exql.Limit(uq.limit)
- }
-
- stmt.SetAmendment(uq.amendFn)
-
- return stmt
- }
-
- func (uq *updaterQuery) arguments() []interface{} {
- return joinArguments(
- uq.columnValuesArgs,
- uq.whereArgs,
- )
- }
-
- type updater struct {
- builder *sqlBuilder
-
- fn func(*updaterQuery) error
- prev *updater
- }
-
- var _ = immutable.Immutable(&updater{})
-
- func (upd *updater) SQLBuilder() *sqlBuilder {
- if upd.prev == nil {
- return upd.builder
- }
- return upd.prev.SQLBuilder()
- }
-
- func (upd *updater) template() *exql.Template {
- return upd.SQLBuilder().t.Template
- }
-
- func (upd *updater) String() string {
- s, err := upd.Compile()
- if err != nil {
- panic(err.Error())
- }
- return prepareQueryForDisplay(s)
- }
-
- func (upd *updater) setTable(table string) *updater {
- return upd.frame(func(uq *updaterQuery) error {
- uq.table = table
- return nil
- })
- }
-
- func (upd *updater) frame(fn func(*updaterQuery) error) *updater {
- return &updater{prev: upd, fn: fn}
- }
-
- func (upd *updater) Set(terms ...interface{}) Updater {
- return upd.frame(func(uq *updaterQuery) error {
- if uq.columnValues == nil {
- uq.columnValues = &exql.ColumnValues{}
- }
-
- if len(terms) == 1 {
- ff, vv, err := Map(terms[0], nil)
- if err == nil && len(ff) > 0 {
- cvs := make([]exql.Fragment, 0, len(ff))
- args := make([]interface{}, 0, len(vv))
-
- for i := range ff {
- cv := &exql.ColumnValue{
- Column: exql.ColumnWithName(ff[i]),
- Operator: upd.SQLBuilder().t.AssignmentOperator,
- }
-
- var localArgs []interface{}
- cv.Value, localArgs = upd.SQLBuilder().t.PlaceholderValue(vv[i])
-
- args = append(args, localArgs...)
- cvs = append(cvs, cv)
- }
-
- uq.columnValues.Insert(cvs...)
- uq.columnValuesArgs = append(uq.columnValuesArgs, args...)
-
- return nil
- }
- }
-
- cv, arguments := upd.SQLBuilder().t.setColumnValues(terms)
- uq.columnValues.Insert(cv.ColumnValues...)
- uq.columnValuesArgs = append(uq.columnValuesArgs, arguments...)
- return nil
- })
- }
-
- func (upd *updater) Amend(fn func(string) string) Updater {
- return upd.frame(func(uq *updaterQuery) error {
- uq.amendFn = fn
- return nil
- })
- }
-
- func (upd *updater) Arguments() []interface{} {
- uq, err := upd.build()
- if err != nil {
- return nil
- }
- return uq.arguments()
- }
-
- func (upd *updater) Where(terms ...interface{}) Updater {
- return upd.frame(func(uq *updaterQuery) error {
- uq.where, uq.whereArgs = &exql.Where{}, []interface{}{}
- return uq.and(upd.SQLBuilder(), terms...)
- })
- }
-
- func (upd *updater) And(terms ...interface{}) Updater {
- return upd.frame(func(uq *updaterQuery) error {
- return uq.and(upd.SQLBuilder(), terms...)
- })
- }
-
- func (upd *updater) Prepare() (*sql.Stmt, error) {
- return upd.PrepareContext(upd.SQLBuilder().sess.Context())
- }
-
- func (upd *updater) PrepareContext(ctx context.Context) (*sql.Stmt, error) {
- uq, err := upd.build()
- if err != nil {
- return nil, err
- }
- return upd.SQLBuilder().sess.StatementPrepare(ctx, uq.statement())
- }
-
- func (upd *updater) Exec() (sql.Result, error) {
- return upd.ExecContext(upd.SQLBuilder().sess.Context())
- }
-
- func (upd *updater) ExecContext(ctx context.Context) (sql.Result, error) {
- uq, err := upd.build()
- if err != nil {
- return nil, err
- }
- return upd.SQLBuilder().sess.StatementExec(ctx, uq.statement(), uq.arguments()...)
- }
-
- func (upd *updater) Limit(limit int) Updater {
- return upd.frame(func(uq *updaterQuery) error {
- uq.limit = limit
- return nil
- })
- }
-
- func (upd *updater) statement() (*exql.Statement, error) {
- iq, err := upd.build()
- if err != nil {
- return nil, err
- }
- return iq.statement(), nil
- }
-
- func (upd *updater) build() (*updaterQuery, error) {
- uq, err := immutable.FastForward(upd)
- if err != nil {
- return nil, err
- }
- return uq.(*updaterQuery), nil
- }
-
- func (upd *updater) Compile() (string, error) {
- s, err := upd.statement()
- if err != nil {
- return "", err
- }
- return s.Compile(upd.template())
- }
-
- func (upd *updater) Prev() immutable.Immutable {
- if upd == nil {
- return nil
- }
- return upd.prev
- }
-
- func (upd *updater) Fn(in interface{}) error {
- if upd.fn == nil {
- return nil
- }
- return upd.fn(in.(*updaterQuery))
- }
-
- func (upd *updater) Base() interface{} {
- return &updaterQuery{}
- }
|