Browse Source

Add a simple repository, sql building and querying (incl. simple relations)

Joachim M. Giæver 4 years ago
parent
commit
f755eb01de
2 changed files with 290 additions and 0 deletions
  1. 187 0
      repository.go
  2. 103 0
      sql.go

+ 187 - 0
repository.go

@@ -0,0 +1,187 @@
+package orm
+
+import (
+	"fmt"
+	"reflect"
+	"strings"
+
+	"git.giaever.org/bnb.hosting/orm/conn"
+	"github.com/huandu/go-sqlbuilder"
+	"github.com/jmoiron/sqlx"
+)
+
+type repository struct {
+	dbc map[string]MappableInterface
+}
+
+var r *repository = nil
+
+func Repository() *repository {
+	if r == nil {
+		r = &repository{
+			make(map[string]MappableInterface),
+		}
+	}
+	return r
+}
+
+func (r *repository) FetchFirst(db *conn.DB, i MappableInterface) (MappableInterface, error) {
+	res, err := r.Fetch(db, i, 1)
+
+	if len(res) == 1 {
+		return res[0], err
+	}
+
+	return nil, err
+}
+
+func (r *repository) Fetch(db *conn.DB, i MappableInterface, offlim ...int) ([]MappableInterface, error) {
+	tbl := ctx.getTbl(i)
+	b := getSelectBuilder(tbl)._select(tbl)._where(tbl, i)
+
+	if len(offlim) == 1 {
+		b.Limit(offlim[0])
+	} else if len(offlim) > 1 {
+		b.Limit(offlim[1])
+		b.Offset(offlim[0])
+	}
+
+	return r.query(db, tbl, b)
+}
+
+func (r *repository) FetchRelated(i MappableInterface, n string) (interface{}, error) {
+	tbl := ctx.getTbl(i)
+
+	for rtype, rels := range tbl.getRelations() {
+		for _, rel := range rels {
+			if rel.f.getFieldName() == n {
+				switch rtype {
+				case hasMany, belongsTo:
+					b := getSelectBuilder(tbl)
+					b._select(rel.table)
+					b._join(tbl, rel)
+					b._wherePrimaryOrElse(tbl, i)
+
+					switch rel.f.MakeType().Kind() {
+					case reflect.Map, reflect.Slice:
+						// No limit?
+					default:
+						b.Limit(1)
+					}
+
+					res, err := r.query(i.GetDB(), rel.table, b)
+
+					if err != nil {
+						return nil, err
+					}
+
+					return r.mapScanResultToInterface(i, tbl, rel, res), nil
+
+				case hasOne:
+					b := getSelectBuilder(tbl)
+					b._select(rel.table)
+					b._join(tbl, rel, sqlbuilder.RightJoin)
+					b._wherePrimaryOrElse(tbl, i)
+					b.Limit(1)
+
+					res, err := r.query(i.GetDB(), rel.table, b)
+
+					if err != nil || len(res) == 0 {
+						return nil, err
+					}
+
+					return res[0], nil
+				}
+			}
+		}
+	}
+
+	return nil, nil
+}
+
+func (r *repository) mapScanResultToInterface(i MappableInterface, tbl *table, rel relation, res []MappableInterface) interface{} {
+	m := rel.f.Make()
+
+	switch m.Kind() {
+	case reflect.Map:
+		for midx, rel := range res {
+			if val, err := tbl.CallMethod(i, "GetId"); err != nil {
+				switch val[0].(type) {
+				case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
+					m.SetMapIndex(reflect.ValueOf(val[0].(uint64)), reflect.ValueOf(rel))
+				}
+			} else {
+				m.SetMapIndex(reflect.ValueOf(uint64(midx)), reflect.ValueOf(rel))
+			}
+		}
+	case reflect.Slice:
+		for _, rel := range res {
+			m = reflect.Append(m, reflect.ValueOf(rel))
+		}
+	default:
+		if len(res) != 0 {
+			m.Set(reflect.ValueOf(res[0]))
+		}
+	}
+
+	return m.Interface()
+}
+
+func (r *repository) mapScan(i MappableInterface, rows *sqlx.Rows, tbl *table) (MappableInterface, error) {
+	res := make(map[string]interface{})
+
+	if err := rows.MapScan(res); err != nil {
+		fmt.Println("Error:", err)
+		return nil, err
+	}
+
+	for n, s := range res {
+		if fn := tbl.hasColumn(n).getFieldName(); len(fn) != 0 {
+			if _, err := tbl.CallMethod(i, "Set"+strings.Title(fn), s); err != nil {
+				fmt.Println("ERROR", err)
+				//return nil, err
+			}
+		} else {
+			fmt.Println("MISSING FIELD NAME")
+		}
+	}
+
+	if col := tbl.getPrimaryKey(); col != nil {
+		if v, err := tbl.CallMethod(i, "Get"+strings.Title(col.getFieldName())); err == nil {
+			idx := fmt.Sprintf("%s_%#v", tbl.getStructName(), v[0])
+			if ni, ok := r.dbc[idx]; ok {
+				return ni, nil
+			} else {
+				r.dbc[idx] = i
+			}
+		}
+	}
+
+	return i, nil
+}
+
+func (r *repository) query(db *conn.DB, tbl *table, sql sqlbuilder.Builder) ([]MappableInterface, error) {
+	var ret []MappableInterface
+
+	q, args := sql.Build()
+	rows, err := db.Queryx(q, args...)
+
+	for err != nil {
+		return nil, err
+	}
+
+	//defer rows.Close()
+	for rows.Next() {
+		i, err := r.mapScan(tbl.Make(), rows, tbl)
+
+		if err == nil {
+			i.SetDB(db)
+			ret = append(ret, i)
+		} else {
+			fmt.Println(err)
+			return nil, err
+		}
+	}
+
+	return ret, nil
+}

+ 103 - 0
sql.go

@@ -0,0 +1,103 @@
+package orm
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/huandu/go-sqlbuilder"
+)
+
+var SqlFlavor = sqlbuilder.MySQL
+
+type SqlType uint8
+
+const (
+	Select SqlType = iota << 1
+	Update
+	Insert
+	Delete
+)
+
+type selectBuilder struct {
+	*sqlbuilder.SelectBuilder
+}
+
+/*type selectBuilder struct {
+	b sqlbuilder.Builder
+}
+type selectBuilder struct {
+	b sqlbuilder.Builder
+}
+type selectBuilder struct {
+	b sqlbuilder.Builder
+}*/
+
+func getSelectBuilder(tbl *table) *selectBuilder {
+	return &selectBuilder{
+		sqlbuilder.NewSelectBuilder().From(tbl.getNameAs(true)),
+	}
+}
+
+func (sb *selectBuilder) _select(tbls ...*table) *selectBuilder {
+	cols := []string{}
+
+	for _, tbl := range tbls {
+		for _, col := range tbl.getColumns() {
+			cols = append(cols, col.getName(true, tbl))
+		}
+	}
+
+	sb.Select(cols...)
+	return sb
+}
+
+func (sb *selectBuilder) _where(tbl *table, i MappableInterface) *selectBuilder {
+	tmp := tbl.Make()
+	where := []string{}
+
+outerloop:
+	for _, col := range tbl.getColumns() {
+		fn := "Get" + strings.Title(col.getFieldName())
+		val, err := tbl.CallMethod(i, fn)
+		if err == nil {
+			tmpv, _ := tbl.CallMethod(tmp, fn)
+			for i, v := range val {
+				if v == tmpv[i] {
+					continue outerloop
+				}
+			}
+			where = append(where, sb.Equal(col.getName(true, tbl), val[0]))
+		} else {
+			fmt.Println(err)
+		}
+	}
+	sb.Where(where...)
+
+	return sb
+}
+
+func (sb *selectBuilder) _wherePrimaryOrElse(tbl *table, i MappableInterface) *selectBuilder {
+
+	if pk := tbl.getPrimaryKey(); pk != nil {
+		val, err := tbl.CallMethod(i, "Get"+strings.Title(pk.getFieldName()))
+		if err != nil {
+			sb._where(tbl, i)
+		} else {
+			sb.Where(sb.Equal(pk.getName(true, tbl), val[0]))
+		}
+	} else {
+		sb._where(tbl, i)
+	}
+
+	return sb
+}
+
+func (sb *selectBuilder) _join(tbl *table, rel relation, opt ...sqlbuilder.JoinOption) *selectBuilder {
+
+	if len(opt) == 0 {
+		sb.Join(rel.getNameAs(true), sb.Equal(rel.on.getName(true, rel.table), sqlbuilder.Raw(rel.key.getName(true, tbl))))
+	} else {
+		sb.JoinWithOption(opt[0], rel.getNameAs(true), sb.Equal(rel.on.getName(true, rel.table), sqlbuilder.Raw(rel.key.getName(true, tbl))))
+	}
+	return sb
+}