|
@@ -0,0 +1,187 @@
|
|
|
+package orm
|
|
|
+
|
|
|
+import (
|
|
|
+ "fmt"
|
|
|
+ "reflect"
|
|
|
+ "strings"
|
|
|
+
|
|
|
+ "git.giaever.org/bnb.hosting/orm/conn"
|
|
|
+ "github.com/huandu/go-sqlbuilder"
|
|
|
+ "github.com/jmoiron/sqlx"
|
|
|
+)
|
|
|
+
|
|
|
+type repository struct {
|
|
|
+ dbc map[string]MappableInterface
|
|
|
+}
|
|
|
+
|
|
|
+var r *repository = nil
|
|
|
+
|
|
|
+func Repository() *repository {
|
|
|
+ if r == nil {
|
|
|
+ r = &repository{
|
|
|
+ make(map[string]MappableInterface),
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return r
|
|
|
+}
|
|
|
+
|
|
|
+func (r *repository) FetchFirst(db *conn.DB, i MappableInterface) (MappableInterface, error) {
|
|
|
+ res, err := r.Fetch(db, i, 1)
|
|
|
+
|
|
|
+ if len(res) == 1 {
|
|
|
+ return res[0], err
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil, err
|
|
|
+}
|
|
|
+
|
|
|
+func (r *repository) Fetch(db *conn.DB, i MappableInterface, offlim ...int) ([]MappableInterface, error) {
|
|
|
+ tbl := ctx.getTbl(i)
|
|
|
+ b := getSelectBuilder(tbl)._select(tbl)._where(tbl, i)
|
|
|
+
|
|
|
+ if len(offlim) == 1 {
|
|
|
+ b.Limit(offlim[0])
|
|
|
+ } else if len(offlim) > 1 {
|
|
|
+ b.Limit(offlim[1])
|
|
|
+ b.Offset(offlim[0])
|
|
|
+ }
|
|
|
+
|
|
|
+ return r.query(db, tbl, b)
|
|
|
+}
|
|
|
+
|
|
|
+func (r *repository) FetchRelated(i MappableInterface, n string) (interface{}, error) {
|
|
|
+ tbl := ctx.getTbl(i)
|
|
|
+
|
|
|
+ for rtype, rels := range tbl.getRelations() {
|
|
|
+ for _, rel := range rels {
|
|
|
+ if rel.f.getFieldName() == n {
|
|
|
+ switch rtype {
|
|
|
+ case hasMany, belongsTo:
|
|
|
+ b := getSelectBuilder(tbl)
|
|
|
+ b._select(rel.table)
|
|
|
+ b._join(tbl, rel)
|
|
|
+ b._wherePrimaryOrElse(tbl, i)
|
|
|
+
|
|
|
+ switch rel.f.MakeType().Kind() {
|
|
|
+ case reflect.Map, reflect.Slice:
|
|
|
+
|
|
|
+ default:
|
|
|
+ b.Limit(1)
|
|
|
+ }
|
|
|
+
|
|
|
+ res, err := r.query(i.GetDB(), rel.table, b)
|
|
|
+
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return r.mapScanResultToInterface(i, tbl, rel, res), nil
|
|
|
+
|
|
|
+ case hasOne:
|
|
|
+ b := getSelectBuilder(tbl)
|
|
|
+ b._select(rel.table)
|
|
|
+ b._join(tbl, rel, sqlbuilder.RightJoin)
|
|
|
+ b._wherePrimaryOrElse(tbl, i)
|
|
|
+ b.Limit(1)
|
|
|
+
|
|
|
+ res, err := r.query(i.GetDB(), rel.table, b)
|
|
|
+
|
|
|
+ if err != nil || len(res) == 0 {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return res[0], nil
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (r *repository) mapScanResultToInterface(i MappableInterface, tbl *table, rel relation, res []MappableInterface) interface{} {
|
|
|
+ m := rel.f.Make()
|
|
|
+
|
|
|
+ switch m.Kind() {
|
|
|
+ case reflect.Map:
|
|
|
+ for midx, rel := range res {
|
|
|
+ if val, err := tbl.CallMethod(i, "GetId"); err != nil {
|
|
|
+ switch val[0].(type) {
|
|
|
+ case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
|
|
|
+ m.SetMapIndex(reflect.ValueOf(val[0].(uint64)), reflect.ValueOf(rel))
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ m.SetMapIndex(reflect.ValueOf(uint64(midx)), reflect.ValueOf(rel))
|
|
|
+ }
|
|
|
+ }
|
|
|
+ case reflect.Slice:
|
|
|
+ for _, rel := range res {
|
|
|
+ m = reflect.Append(m, reflect.ValueOf(rel))
|
|
|
+ }
|
|
|
+ default:
|
|
|
+ if len(res) != 0 {
|
|
|
+ m.Set(reflect.ValueOf(res[0]))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return m.Interface()
|
|
|
+}
|
|
|
+
|
|
|
+func (r *repository) mapScan(i MappableInterface, rows *sqlx.Rows, tbl *table) (MappableInterface, error) {
|
|
|
+ res := make(map[string]interface{})
|
|
|
+
|
|
|
+ if err := rows.MapScan(res); err != nil {
|
|
|
+ fmt.Println("Error:", err)
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ for n, s := range res {
|
|
|
+ if fn := tbl.hasColumn(n).getFieldName(); len(fn) != 0 {
|
|
|
+ if _, err := tbl.CallMethod(i, "Set"+strings.Title(fn), s); err != nil {
|
|
|
+ fmt.Println("ERROR", err)
|
|
|
+
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ fmt.Println("MISSING FIELD NAME")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if col := tbl.getPrimaryKey(); col != nil {
|
|
|
+ if v, err := tbl.CallMethod(i, "Get"+strings.Title(col.getFieldName())); err == nil {
|
|
|
+ idx := fmt.Sprintf("%s_%#v", tbl.getStructName(), v[0])
|
|
|
+ if ni, ok := r.dbc[idx]; ok {
|
|
|
+ return ni, nil
|
|
|
+ } else {
|
|
|
+ r.dbc[idx] = i
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return i, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (r *repository) query(db *conn.DB, tbl *table, sql sqlbuilder.Builder) ([]MappableInterface, error) {
|
|
|
+ var ret []MappableInterface
|
|
|
+
|
|
|
+ q, args := sql.Build()
|
|
|
+ rows, err := db.Queryx(q, args...)
|
|
|
+
|
|
|
+ for err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+ for rows.Next() {
|
|
|
+ i, err := r.mapScan(tbl.Make(), rows, tbl)
|
|
|
+
|
|
|
+ if err == nil {
|
|
|
+ i.SetDB(db)
|
|
|
+ ret = append(ret, i)
|
|
|
+ } else {
|
|
|
+ fmt.Println(err)
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return ret, nil
|
|
|
+}
|