123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- package orm
import (
	"fmt"
	"reflect"
	"strings"
	"unicode/utf8"

	"git.giaever.org/bnb.hosting/orm/conn"
	"github.com/huandu/go-sqlbuilder"
	"github.com/jmoiron/sqlx"
)
- type repository struct {
- dbc map[string]MappableInterface
- }
- var r *repository = nil
- func Repository() *repository {
- if r == nil {
- r = &repository{
- make(map[string]MappableInterface),
- }
- }
- return r
- }
- func (r *repository) FetchFirst(db *conn.DB, i MappableInterface) (MappableInterface, error) {
- res, err := r.Fetch(db, i, (Cond{Method: Limit}).SetVar(1))
- if len(res) == 1 {
- return res[0], err
- }
- return nil, err
- }
- func (r *repository) Fetch(db *conn.DB, i MappableInterface, conds ...Cond) ([]MappableInterface, error) {
- tbl := ctx.getTbl(i)
- b := getSelectBuilder(tbl)._select(tbl)._where(tbl, i)
- b._extra(tbl, conds...)
- return r.query(db, tbl, b)
- }
- func (r *repository) FetchRelated(i MappableInterface, n string, conds ...Cond) (interface{}, error) {
- tbl := ctx.getTbl(i)
- for rtype, rels := range tbl.getRelations() {
- for _, rel := range rels {
- if rel.f.getFieldName() == n {
- switch rtype {
- case hasMany, belongsTo:
- b := getSelectBuilder(tbl)
- b._select(rel)
- b._join(tbl, rel)
- b._wherePrimaryOrElse(tbl, i)
- b._extra(tbl, conds...)
- switch rel.f.MakeType().Kind() {
- case reflect.Map, reflect.Slice:
- // No limit?
- default:
- b.Limit(1)
- }
- /*if n == "parentUnitOption" {
- fmt.Println(b.Build())
- }*/
- res, err := r.query(i.GetDB(), rel.table, b)
- if err != nil {
- return nil, err
- }
- return r.mapScanResultToInterface(i, tbl, rel, res), nil
- case hasOne:
- b := getSelectBuilder(tbl)
- b._select(rel)
- b._join(tbl, rel, sqlbuilder.RightJoin)
- b._wherePrimaryOrElse(tbl, i)
- b.Limit(1)
- res, err := r.query(i.GetDB(), rel.table, b)
- if err != nil || len(res) == 0 {
- return nil, err
- }
- return res[0], nil
- }
- }
- }
- }
- return nil, nil
- }
- func (r *repository) mapScanResultToInterface(i MappableInterface, tbl *table, rel relation, res []MappableInterface) interface{} {
- m := rel.f.Make()
- switch m.Kind() {
- case reflect.Map:
- for midx, rel := range res {
- if val, err := tbl.CallMethod(i, "GetId"); err != nil {
- switch val[0].(type) {
- case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
- m.SetMapIndex(reflect.ValueOf(val[0].(uint64)), reflect.ValueOf(rel))
- }
- } else {
- m.SetMapIndex(reflect.ValueOf(uint64(midx)), reflect.ValueOf(rel))
- }
- }
- case reflect.Slice:
- for _, rel := range res {
- m = reflect.Append(m, reflect.ValueOf(rel))
- }
- default:
- if len(res) != 0 {
- m.Set(reflect.ValueOf(res[0]))
- }
- }
- return m.Interface()
- }
- func (r *repository) mapScan(i MappableInterface, rows *sqlx.Rows, tbl *table) (MappableInterface, error) {
- res := make(map[string]interface{})
- if err := rows.MapScan(res); err != nil {
- return nil, err
- }
- for n, s := range res {
- if fn := tbl.hasColumn(n).getFieldName(); len(fn) != 0 {
- if s != nil {
- if _, err := tbl.CallMethod(i, "Set"+strings.Title(fn), s); err != nil {
- fmt.Println(err)
- //return nil, err
- }
- }
- } else {
- fmt.Println("MISSING FIELD NAME", n, s)
- }
- }
- if col := tbl.getPrimaryKey(); col != nil {
- if v, err := tbl.CallMethod(i, "Get"+strings.Title(col.getFieldName())); err == nil {
- idx := fmt.Sprintf("%s_%#v", tbl.getStructName(), v[0])
- if ni, ok := r.dbc[idx]; ok {
- return ni, nil
- } else {
- r.dbc[idx] = i
- }
- }
- }
- return i, nil
- }
- func (r *repository) query(db *conn.DB, tbl *table, sql sqlbuilder.Builder) ([]MappableInterface, error) {
- var ret []MappableInterface
- q, args := sql.Build()
- rows, err := db.Queryx(q, args...)
- for err != nil {
- return nil, err
- }
- //defer rows.Close()
- for rows.Next() {
- i, err := r.mapScan(tbl.Make(), rows, tbl)
- if err == nil {
- i.SetDB(db)
- ret = append(ret, i)
- } else {
- fmt.Println(err)
- return nil, err
- }
- }
- return ret, nil
- }
|