package orm import ( "fmt" "reflect" "strings" "git.giaever.org/bnb.hosting/orm/conn" "github.com/huandu/go-sqlbuilder" "github.com/jmoiron/sqlx" ) type repository struct { dbc map[string]MappableInterface } var r *repository = nil func Repository() *repository { if r == nil { r = &repository{ make(map[string]MappableInterface), } } return r } func (r *repository) FetchFirst(db *conn.DB, i MappableInterface) (MappableInterface, error) { res, err := r.Fetch(db, i, (Cond{Method: Limit}).SetVar(1)) if len(res) == 1 { return res[0], err } return nil, err } func (r *repository) Fetch(db *conn.DB, i MappableInterface, conds ...Cond) ([]MappableInterface, error) { tbl := ctx.getTbl(i) b := getSelectBuilder(tbl)._select(tbl)._where(tbl, i) b._extra(tbl, conds...) return r.query(db, tbl, b) } func (r *repository) FetchRelated(i MappableInterface, n string, conds ...Cond) (interface{}, error) { tbl := ctx.getTbl(i) for rtype, rels := range tbl.getRelations() { for _, rel := range rels { if rel.f.getFieldName() == n { switch rtype { case hasMany, belongsTo: b := getSelectBuilder(tbl) b._select(rel) b._join(tbl, rel) b._wherePrimaryOrElse(tbl, i) b._extra(tbl, conds...) switch rel.f.MakeType().Kind() { case reflect.Map, reflect.Slice: // No limit? default: b.Limit(1) } /*if n == "parentUnitOption" { fmt.Println(b.Build()) }*/ res, err := r.query(i.GetDB(), rel.table, b) if err != nil { return nil, err } return r.mapScanResultToInterface(i, tbl, rel, res), nil case hasOne: b := getSelectBuilder(tbl) b._select(rel) b._join(tbl, rel, sqlbuilder.RightJoin) b._wherePrimaryOrElse(tbl, i) b.Limit(1) res, err := r.query(i.GetDB(), rel.table, b) if err != nil || len(res) == 0 { return nil, err } return res[0], nil } } } } return nil, nil } func (r *repository) mapScanResultToInterface(i MappableInterface, tbl *table, rel relation, res []MappableInterface) interface{} { m := rel.f.Make() switch m.Kind() { case reflect.Map: for midx, rel := range res { if val, err := tbl.CallMethod(i, "GetId"); err != nil { switch val[0].(type) { case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64: m.SetMapIndex(reflect.ValueOf(val[0].(uint64)), reflect.ValueOf(rel)) } } else { m.SetMapIndex(reflect.ValueOf(uint64(midx)), reflect.ValueOf(rel)) } } case reflect.Slice: for _, rel := range res { m = reflect.Append(m, reflect.ValueOf(rel)) } default: if len(res) != 0 { m.Set(reflect.ValueOf(res[0])) } } return m.Interface() } func (r *repository) mapScan(i MappableInterface, rows *sqlx.Rows, tbl *table) (MappableInterface, error) { res := make(map[string]interface{}) if err := rows.MapScan(res); err != nil { return nil, err } for n, s := range res { if fn := tbl.hasColumn(n).getFieldName(); len(fn) != 0 { if s != nil { if _, err := tbl.CallMethod(i, "Set"+strings.Title(fn), s); err != nil { fmt.Println(err) //return nil, err } } } else { fmt.Println("MISSING FIELD NAME", n, s) } } if col := tbl.getPrimaryKey(); col != nil { if v, err := tbl.CallMethod(i, "Get"+strings.Title(col.getFieldName())); err == nil { idx := fmt.Sprintf("%s_%#v", tbl.getStructName(), v[0]) if ni, ok := r.dbc[idx]; ok { return ni, nil } else { r.dbc[idx] = i } } } return i, nil } func (r *repository) query(db *conn.DB, tbl *table, sql sqlbuilder.Builder) ([]MappableInterface, error) { var ret []MappableInterface q, args := sql.Build() rows, err := db.Queryx(q, args...) for err != nil { return nil, err } //defer rows.Close() for rows.Next() { i, err := r.mapScan(tbl.Make(), rows, tbl) if err == nil { i.SetDB(db) ret = append(ret, i) } else { fmt.Println(err) return nil, err } } return ret, nil }