123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288 |
- package sqlbuilder
-
- import (
- "context"
- "database/sql"
-
- "upper.io/db.v3/internal/immutable"
- "upper.io/db.v3/internal/sqladapter/exql"
- )
-
- type inserterQuery struct {
- table string
- enqueuedValues [][]interface{}
- returning []exql.Fragment
- columns []exql.Fragment
- values []*exql.Values
- arguments []interface{}
- extra string
- amendFn func(string) string
- }
-
- func (iq *inserterQuery) processValues() ([]*exql.Values, []interface{}, error) {
- var values []*exql.Values
- var arguments []interface{}
-
- var mapOptions *MapOptions
- if len(iq.enqueuedValues) > 1 {
- mapOptions = &MapOptions{IncludeZeroed: true, IncludeNil: true}
- }
-
- for _, enqueuedValue := range iq.enqueuedValues {
- if len(enqueuedValue) == 1 {
- // If and only if we passed one argument to Values.
- ff, vv, err := Map(enqueuedValue[0], mapOptions)
-
- if err == nil {
- // If we didn't have any problem with mapping we can convert it into
- // columns and values.
- columns, vals, args, _ := toColumnsValuesAndArguments(ff, vv)
-
- values, arguments = append(values, vals), append(arguments, args...)
-
- if len(iq.columns) == 0 {
- for _, c := range columns.Columns {
- iq.columns = append(iq.columns, c)
- }
- }
- continue
- }
-
- // The only error we can expect without exiting is this argument not
- // being a map or struct, in which case we can continue.
- if err != ErrExpectingPointerToEitherMapOrStruct {
- return nil, nil, err
- }
- }
-
- if len(iq.columns) == 0 || len(enqueuedValue) == len(iq.columns) {
- arguments = append(arguments, enqueuedValue...)
-
- l := len(enqueuedValue)
- placeholders := make([]exql.Fragment, l)
- for i := 0; i < l; i++ {
- placeholders[i] = exql.RawValue(`?`)
- }
- values = append(values, exql.NewValueGroup(placeholders...))
- }
- }
-
- return values, arguments, nil
- }
-
- func (iq *inserterQuery) statement() *exql.Statement {
- stmt := &exql.Statement{
- Type: exql.Insert,
- Table: exql.TableWithName(iq.table),
- }
-
- if len(iq.values) > 0 {
- stmt.Values = exql.JoinValueGroups(iq.values...)
- }
-
- if len(iq.columns) > 0 {
- stmt.Columns = exql.JoinColumns(iq.columns...)
- }
-
- if len(iq.returning) > 0 {
- stmt.Returning = exql.ReturningColumns(iq.returning...)
- }
-
- stmt.SetAmendment(iq.amendFn)
-
- return stmt
- }
-
- type inserter struct {
- builder *sqlBuilder
-
- fn func(*inserterQuery) error
- prev *inserter
- }
-
- var _ = immutable.Immutable(&inserter{})
-
- func (ins *inserter) SQLBuilder() *sqlBuilder {
- if ins.prev == nil {
- return ins.builder
- }
- return ins.prev.SQLBuilder()
- }
-
- func (ins *inserter) template() *exql.Template {
- return ins.SQLBuilder().t.Template
- }
-
- func (ins *inserter) String() string {
- s, err := ins.Compile()
- if err != nil {
- panic(err.Error())
- }
- return prepareQueryForDisplay(s)
- }
-
- func (ins *inserter) frame(fn func(*inserterQuery) error) *inserter {
- return &inserter{prev: ins, fn: fn}
- }
-
- func (ins *inserter) Batch(n int) *BatchInserter {
- return newBatchInserter(ins, n)
- }
-
- func (ins *inserter) Amend(fn func(string) string) Inserter {
- return ins.frame(func(iq *inserterQuery) error {
- iq.amendFn = fn
- return nil
- })
- }
-
- func (ins *inserter) Arguments() []interface{} {
- iq, err := ins.build()
- if err != nil {
- return nil
- }
- return iq.arguments
- }
-
- func (ins *inserter) Returning(columns ...string) Inserter {
- return ins.frame(func(iq *inserterQuery) error {
- columnsToFragments(&iq.returning, columns)
- return nil
- })
- }
-
- func (ins *inserter) Exec() (sql.Result, error) {
- return ins.ExecContext(ins.SQLBuilder().sess.Context())
- }
-
- func (ins *inserter) ExecContext(ctx context.Context) (sql.Result, error) {
- iq, err := ins.build()
- if err != nil {
- return nil, err
- }
- return ins.SQLBuilder().sess.StatementExec(ctx, iq.statement(), iq.arguments...)
- }
-
- func (ins *inserter) Prepare() (*sql.Stmt, error) {
- return ins.PrepareContext(ins.SQLBuilder().sess.Context())
- }
-
- func (ins *inserter) PrepareContext(ctx context.Context) (*sql.Stmt, error) {
- iq, err := ins.build()
- if err != nil {
- return nil, err
- }
- return ins.SQLBuilder().sess.StatementPrepare(ctx, iq.statement())
- }
-
- func (ins *inserter) Query() (*sql.Rows, error) {
- return ins.QueryContext(ins.SQLBuilder().sess.Context())
- }
-
- func (ins *inserter) QueryContext(ctx context.Context) (*sql.Rows, error) {
- iq, err := ins.build()
- if err != nil {
- return nil, err
- }
- return ins.SQLBuilder().sess.StatementQuery(ctx, iq.statement(), iq.arguments...)
- }
-
- func (ins *inserter) QueryRow() (*sql.Row, error) {
- return ins.QueryRowContext(ins.SQLBuilder().sess.Context())
- }
-
- func (ins *inserter) QueryRowContext(ctx context.Context) (*sql.Row, error) {
- iq, err := ins.build()
- if err != nil {
- return nil, err
- }
- return ins.SQLBuilder().sess.StatementQueryRow(ctx, iq.statement(), iq.arguments...)
- }
-
- func (ins *inserter) Iterator() Iterator {
- return ins.IteratorContext(ins.SQLBuilder().sess.Context())
- }
-
- func (ins *inserter) IteratorContext(ctx context.Context) Iterator {
- rows, err := ins.QueryContext(ctx)
- return &iterator{ins.SQLBuilder().sess, rows, err}
- }
-
- func (ins *inserter) Into(table string) Inserter {
- return ins.frame(func(iq *inserterQuery) error {
- iq.table = table
- return nil
- })
- }
-
- func (ins *inserter) Columns(columns ...string) Inserter {
- return ins.frame(func(iq *inserterQuery) error {
- columnsToFragments(&iq.columns, columns)
- return nil
- })
- }
-
- func (ins *inserter) Values(values ...interface{}) Inserter {
- return ins.frame(func(iq *inserterQuery) error {
- iq.enqueuedValues = append(iq.enqueuedValues, values)
- return nil
- })
- }
-
- func (ins *inserter) statement() (*exql.Statement, error) {
- iq, err := ins.build()
- if err != nil {
- return nil, err
- }
- return iq.statement(), nil
- }
-
- func (ins *inserter) build() (*inserterQuery, error) {
- iq, err := immutable.FastForward(ins)
- if err != nil {
- return nil, err
- }
- ret := iq.(*inserterQuery)
- ret.values, ret.arguments, err = ret.processValues()
- if err != nil {
- return nil, err
- }
- return ret, nil
- }
-
- func (ins *inserter) Compile() (string, error) {
- s, err := ins.statement()
- if err != nil {
- return "", err
- }
- return s.Compile(ins.template())
- }
-
- func (ins *inserter) Prev() immutable.Immutable {
- if ins == nil {
- return nil
- }
- return ins.prev
- }
-
- func (ins *inserter) Fn(in interface{}) error {
- if ins.fn == nil {
- return nil
- }
- return ins.fn(in.(*inserterQuery))
- }
-
- func (ins *inserter) Base() interface{} {
- return &inserterQuery{}
- }
-
- func columnsToFragments(dst *[]exql.Fragment, columns []string) error {
- l := len(columns)
- f := make([]exql.Fragment, l)
- for i := 0; i < l; i++ {
- f[i] = exql.ColumnWithName(columns[i])
- }
- *dst = append(*dst, f...)
- return nil
- }
|