123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- package sqladapter
-
- import (
- "errors"
- "fmt"
- "reflect"
-
- "upper.io/db.v3"
- "upper.io/db.v3/internal/sqladapter/exql"
- "upper.io/db.v3/lib/reflectx"
- )
-
- var mapper = reflectx.NewMapper("db")
-
- var errMissingPrimaryKeys = errors.New("Table %q has no primary keys")
-
- // Collection represents a SQL table.
- type Collection interface {
- PartialCollection
- BaseCollection
- }
-
- // PartialCollection defines methods that must be implemented by the adapter.
- type PartialCollection interface {
- // Database returns the parent database.
- Database() Database
-
- // Name returns the name of the table.
- Name() string
-
- // Insert inserts a new item into the collection.
- Insert(interface{}) (interface{}, error)
- }
-
- // BaseCollection provides logic for methods that can be shared across all SQL
- // adapters.
- type BaseCollection interface {
- // Exists returns true if the collection exists.
- Exists() bool
-
- // Find creates and returns a new result set.
- Find(conds ...interface{}) db.Result
-
- // Truncate removes all items on the collection.
- Truncate() error
-
- // InsertReturning inserts a new item and updates it with the
- // actual values from the database.
- InsertReturning(interface{}) error
-
- // UpdateReturning updates an item and returns the actual values from the
- // database.
- UpdateReturning(interface{}) error
-
- // PrimaryKeys returns the table's primary keys.
- PrimaryKeys() []string
- }
-
// condsFilter is an optional hook. When the adapter's PartialCollection
// implements it, filterConds hands the conditions to FilterConds and uses
// whatever it returns instead of applying the default rewrite.
type condsFilter interface {
	FilterConds(...interface{}) []interface{}
}
-
- // collection is the implementation of Collection.
- type collection struct {
- BaseCollection
- PartialCollection
-
- pk []string
- err error
- }
-
- var (
- _ = Collection(&collection{})
- )
-
- // NewBaseCollection returns a collection with basic methods.
- func NewBaseCollection(p PartialCollection) BaseCollection {
- c := &collection{PartialCollection: p}
- c.pk, c.err = c.Database().PrimaryKeys(c.Name())
- return c
- }
-
- // PrimaryKeys returns the collection's primary keys, if any.
- func (c *collection) PrimaryKeys() []string {
- return c.pk
- }
-
- func (c *collection) filterConds(conds ...interface{}) []interface{} {
- if tr, ok := c.PartialCollection.(condsFilter); ok {
- return tr.FilterConds(conds...)
- }
- if len(conds) == 1 && len(c.pk) == 1 {
- if id := conds[0]; IsKeyValue(id) {
- conds[0] = db.Cond{c.pk[0]: db.Eq(id)}
- }
- }
- return conds
- }
-
- // Find creates a result set with the given conditions.
- func (c *collection) Find(conds ...interface{}) db.Result {
- if c.err != nil {
- res := &Result{}
- res.setErr(c.err)
- return res
- }
- return NewResult(
- c.Database(),
- c.Name(),
- c.filterConds(conds...),
- )
- }
-
- // Exists returns true if the collection exists.
- func (c *collection) Exists() bool {
- if err := c.Database().TableExists(c.Name()); err != nil {
- return false
- }
- return true
- }
-
- // InsertReturning inserts an item and updates the given variable reference.
- func (c *collection) InsertReturning(item interface{}) error {
- if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
- return fmt.Errorf("Expecting a pointer but got %T", item)
- }
-
- // Grab primary keys
- pks := c.PrimaryKeys()
- if len(pks) == 0 {
- if !c.Exists() {
- return db.ErrCollectionDoesNotExist
- }
- return fmt.Errorf(errMissingPrimaryKeys.Error(), c.Name())
- }
-
- var tx DatabaseTx
- inTx := false
-
- if currTx := c.Database().Transaction(); currTx != nil {
- tx = NewDatabaseTx(c.Database())
- inTx = true
- } else {
- // Not within a transaction, let's create one.
- var err error
- tx, err = c.Database().NewDatabaseTx(c.Database().Context())
- if err != nil {
- return err
- }
- defer tx.(Database).Close()
- }
-
- // Allocate a clone of item.
- newItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface()
- var newItemFieldMap map[string]reflect.Value
-
- itemValue := reflect.ValueOf(item)
-
- col := tx.(Database).Collection(c.Name())
-
- // Insert item as is and grab the returning ID.
- id, err := col.Insert(item)
- if err != nil {
- goto cancel
- }
- if id == nil {
- err = fmt.Errorf("InsertReturning: Could not get a valid ID after inserting. Does the %q table have a primary key?", c.Name())
- goto cancel
- }
-
- // Fetch the row that was just interted into newItem
- err = col.Find(id).One(newItem)
- if err != nil {
- goto cancel
- }
-
- switch reflect.ValueOf(newItem).Elem().Kind() {
- case reflect.Struct:
- // Get valid fields from newItem to overwrite those that are on item.
- newItemFieldMap = mapper.ValidFieldMap(reflect.ValueOf(newItem))
- for fieldName := range newItemFieldMap {
- mapper.FieldByName(itemValue, fieldName).Set(newItemFieldMap[fieldName])
- }
- case reflect.Map:
- newItemV := reflect.ValueOf(newItem).Elem()
- itemV := reflect.ValueOf(item)
- if itemV.Kind() == reflect.Ptr {
- itemV = itemV.Elem()
- }
- for _, keyV := range newItemV.MapKeys() {
- itemV.SetMapIndex(keyV, newItemV.MapIndex(keyV))
- }
- default:
- err = fmt.Errorf("InsertReturning: expecting a pointer to map or struct, got %T", newItem)
- goto cancel
- }
-
- if !inTx {
- // This is only executed if t.Database() was **not** a transaction and if
- // sess was created with sess.NewTransaction().
- return tx.Commit()
- }
-
- return err
-
- cancel:
- // This goto label should only be used when we got an error within a
- // transaction and we don't want to continue.
-
- if !inTx {
- // This is only executed if t.Database() was **not** a transaction and if
- // sess was created with sess.NewTransaction().
- tx.Rollback()
- }
- return err
- }
-
- func (c *collection) UpdateReturning(item interface{}) error {
- if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
- return fmt.Errorf("Expecting a pointer but got %T", item)
- }
-
- // Grab primary keys
- pks := c.PrimaryKeys()
- if len(pks) == 0 {
- if !c.Exists() {
- return db.ErrCollectionDoesNotExist
- }
- return fmt.Errorf(errMissingPrimaryKeys.Error(), c.Name())
- }
-
- var tx DatabaseTx
- inTx := false
-
- if currTx := c.Database().Transaction(); currTx != nil {
- tx = NewDatabaseTx(c.Database())
- inTx = true
- } else {
- // Not within a transaction, let's create one.
- var err error
- tx, err = c.Database().NewDatabaseTx(c.Database().Context())
- if err != nil {
- return err
- }
- defer tx.(Database).Close()
- }
-
- // Allocate a clone of item.
- defaultItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface()
- var defaultItemFieldMap map[string]reflect.Value
-
- itemValue := reflect.ValueOf(item)
-
- conds := db.Cond{}
- for _, pk := range pks {
- conds[pk] = db.Eq(mapper.FieldByName(itemValue, pk).Interface())
- }
-
- col := tx.(Database).Collection(c.Name())
-
- err := col.Find(conds).Update(item)
- if err != nil {
- goto cancel
- }
-
- if err = col.Find(conds).One(defaultItem); err != nil {
- goto cancel
- }
-
- switch reflect.ValueOf(defaultItem).Elem().Kind() {
- case reflect.Struct:
- // Get valid fields from defaultItem to overwrite those that are on item.
- defaultItemFieldMap = mapper.ValidFieldMap(reflect.ValueOf(defaultItem))
- for fieldName := range defaultItemFieldMap {
- mapper.FieldByName(itemValue, fieldName).Set(defaultItemFieldMap[fieldName])
- }
- case reflect.Map:
- defaultItemV := reflect.ValueOf(defaultItem).Elem()
- itemV := reflect.ValueOf(item)
- if itemV.Kind() == reflect.Ptr {
- itemV = itemV.Elem()
- }
- for _, keyV := range defaultItemV.MapKeys() {
- itemV.SetMapIndex(keyV, defaultItemV.MapIndex(keyV))
- }
- default:
- panic("default")
- }
-
- if !inTx {
- // This is only executed if t.Database() was **not** a transaction and if
- // sess was created with sess.NewTransaction().
- return tx.Commit()
- }
- return err
-
- cancel:
- // This goto label should only be used when we got an error within a
- // transaction and we don't want to continue.
-
- if !inTx {
- // This is only executed if t.Database() was **not** a transaction and if
- // sess was created with sess.NewTransaction().
- tx.Rollback()
- }
- return err
- }
-
- // Truncate deletes all rows from the table.
- func (c *collection) Truncate() error {
- stmt := exql.Statement{
- Type: exql.Truncate,
- Table: exql.TableWithName(c.Name()),
- }
- if _, err := c.Database().Exec(&stmt); err != nil {
- return err
- }
- return nil
- }
|