Browse Source

Add relations on structs/ptr outside of db-models (e.g from time package),

Joachim M. Giæver 2 years ago
parent
commit
d74d3209da
7 changed files with 193 additions and 73 deletions
  1. 8 4
      field.go
  2. 0 2
      mappable.go
  3. 4 0
      mapper.go
  4. 4 0
      relation.go
  5. 17 18
      repository.go
  6. 112 13
      sql.go
  7. 48 36
      table.go

+ 8 - 4
field.go

@@ -6,11 +6,11 @@ import (
 )
 
 // fieldType is the type we expect
-type fieldType reflect.Kind
+type fieldType uint8
 
 // Relation or Column defines type of mapping
 const (
-	Relation fieldType = (fieldType)(reflect.Struct)
+	Relation fieldType = iota << 1 // fieldType = (fieldType)(reflect.Struct)
 	Column
 )
 
@@ -19,11 +19,15 @@ type field struct {
 	sf reflect.StructField
 	t  reflect.Type
 	v  reflect.Value
+	ft fieldType
 }
 
 func (f *field) Make() reflect.Value {
 	o := reflect.New(f.MakeType()).Elem()
-	o.Set(reflect.MakeMap(o.Type()))
+	switch f.MakeType().Kind() {
+	case reflect.Map:
+		o.Set(reflect.MakeMap(o.Type()))
+	}
 	return o
 }
 
@@ -45,7 +49,7 @@ func (f *field) getFieldType() string {
 
 // getType returns the type; Relation or Column
 func (f *field) getType() fieldType {
-	return (fieldType)(f.t.Kind())
+	return f.ft //(fieldType)(f.t.Kind())
 }
 
 // getKind returns the actual reflect.Kind

+ 0 - 2
mappable.go

@@ -14,8 +14,6 @@ type MappableInterface interface {
 	GetColumnMapperFn() MapperFn
 }
 
-type MappedCollection map[uint64]MappableInterface
-
 type MapperFn func(string) string
 
 type Mappable struct {

+ 4 - 0
mapper.go

@@ -211,6 +211,7 @@ func (m *mapper) mapField(t *table, csf reflect.StructField, cv reflect.Value) f
 					sf: csf,
 					t:  x.Elem(),
 					v:  cv,
+					ft: Relation,
 				}
 			}
 		}
@@ -219,6 +220,7 @@ func (m *mapper) mapField(t *table, csf reflect.StructField, cv reflect.Value) f
 				sf: csf,
 				t:  csf.Type.Elem(),
 				v:  cv,
+				ft: Relation,
 			}
 		}
 	case reflect.Struct:
@@ -227,6 +229,7 @@ func (m *mapper) mapField(t *table, csf reflect.StructField, cv reflect.Value) f
 				sf: csf,
 				t:  csf.Type,
 				v:  cv,
+				ft: Relation,
 			}
 		}
 	}
@@ -234,6 +237,7 @@ func (m *mapper) mapField(t *table, csf reflect.StructField, cv reflect.Value) f
 		sf: csf,
 		t:  csf.Type,
 		v:  cv,
+		ft: Column,
 	}
 }
 

+ 4 - 0
relation.go

@@ -29,6 +29,10 @@ func (r relation) getAlias(q bool) string {
 	return r.f.getFieldName()
 }
 
+func (r relation) getNameAs(q bool) string {
+	return r.getName(q) + " AS " + r.getAlias(q)
+}
+
 // relations holds all relation on types
 type relations struct {
 	rmap map[relType][]relation

+ 17 - 18
repository.go

@@ -26,7 +26,7 @@ func Repository() *repository {
 }
 
 func (r *repository) FetchFirst(db *conn.DB, i MappableInterface) (MappableInterface, error) {
-	res, err := r.Fetch(db, i, 1)
+	res, err := r.Fetch(db, i, (Cond{Method: Limit}).SetVar(1))
 
 	if len(res) == 1 {
 		return res[0], err
@@ -35,21 +35,14 @@ func (r *repository) FetchFirst(db *conn.DB, i MappableInterface) (MappableInter
 	return nil, err
 }
 
-func (r *repository) Fetch(db *conn.DB, i MappableInterface, offlim ...int) ([]MappableInterface, error) {
+func (r *repository) Fetch(db *conn.DB, i MappableInterface, conds ...Cond) ([]MappableInterface, error) {
 	tbl := ctx.getTbl(i)
 	b := getSelectBuilder(tbl)._select(tbl)._where(tbl, i)
-
-	if len(offlim) == 1 {
-		b.Limit(offlim[0])
-	} else if len(offlim) > 1 {
-		b.Limit(offlim[1])
-		b.Offset(offlim[0])
-	}
-
+	b._extra(tbl, conds...)
 	return r.query(db, tbl, b)
 }
 
-func (r *repository) FetchRelated(i MappableInterface, n string) (interface{}, error) {
+func (r *repository) FetchRelated(i MappableInterface, n string, conds ...Cond) (interface{}, error) {
 	tbl := ctx.getTbl(i)
 
 	for rtype, rels := range tbl.getRelations() {
@@ -58,9 +51,10 @@ func (r *repository) FetchRelated(i MappableInterface, n string) (interface{}, e
 				switch rtype {
 				case hasMany, belongsTo:
 					b := getSelectBuilder(tbl)
-					b._select(rel.table)
+					b._select(rel)
 					b._join(tbl, rel)
 					b._wherePrimaryOrElse(tbl, i)
+					b._extra(tbl, conds...)
 
 					switch rel.f.MakeType().Kind() {
 					case reflect.Map, reflect.Slice:
@@ -69,6 +63,10 @@ func (r *repository) FetchRelated(i MappableInterface, n string) (interface{}, e
 						b.Limit(1)
 					}
 
+					/*if n == "parentUnitOption" {
+						fmt.Println(b.Build())
+					}*/
+
 					res, err := r.query(i.GetDB(), rel.table, b)
 
 					if err != nil {
@@ -79,7 +77,7 @@ func (r *repository) FetchRelated(i MappableInterface, n string) (interface{}, e
 
 				case hasOne:
 					b := getSelectBuilder(tbl)
-					b._select(rel.table)
+					b._select(rel)
 					b._join(tbl, rel, sqlbuilder.RightJoin)
 					b._wherePrimaryOrElse(tbl, i)
 					b.Limit(1)
@@ -131,18 +129,19 @@ func (r *repository) mapScan(i MappableInterface, rows *sqlx.Rows, tbl *table) (
 	res := make(map[string]interface{})
 
 	if err := rows.MapScan(res); err != nil {
-		fmt.Println("Error:", err)
 		return nil, err
 	}
 
 	for n, s := range res {
 		if fn := tbl.hasColumn(n).getFieldName(); len(fn) != 0 {
-			if _, err := tbl.CallMethod(i, "Set"+strings.Title(fn), s); err != nil {
-				fmt.Println("ERROR", err)
-				//return nil, err
+			if s != nil {
+				if _, err := tbl.CallMethod(i, "Set"+strings.Title(fn), s); err != nil {
+					fmt.Println(err)
+					//return nil, err
+				}
 			}
 		} else {
-			fmt.Println("MISSING FIELD NAME")
+			fmt.Println("MISSING FIELD NAME", n, s)
 		}
 	}
 

+ 112 - 13
sql.go

@@ -2,6 +2,7 @@ package orm
 
 import (
 	"fmt"
+	"reflect"
 	"strings"
 
 	"github.com/huandu/go-sqlbuilder"
@@ -18,19 +19,47 @@ const (
 	Delete
 )
 
-type selectBuilder struct {
-	*sqlbuilder.SelectBuilder
+type cmethod uint8
+type ctype uint8
+
+const (
+	And ctype = iota
+	Or
+	Between
+	NotBetween
+	EqualThan
+	NotEqualThan
+	GreaterThan
+	GreatenThanEqual
+	LessThan
+	LessThanEqual
+	Like
+	NotLike
+	In
+	NotIn
+	Null
+	NotNull
+)
+
+const (
+	Limit cmethod = iota
+	Where
+)
+
+type Cond struct {
+	Method cmethod
+	Type   ctype
+	vars   []interface{}
 }
 
-/*type selectBuilder struct {
-	b sqlbuilder.Builder
+func (c Cond) SetVar(v ...interface{}) Cond {
+	c.vars = append(c.vars, v...)
+	return c
 }
+
 type selectBuilder struct {
-	b sqlbuilder.Builder
+	*sqlbuilder.SelectBuilder
 }
-type selectBuilder struct {
-	b sqlbuilder.Builder
-}*/
 
 func getSelectBuilder(tbl *table) *selectBuilder {
 	return &selectBuilder{
@@ -38,12 +67,81 @@ func getSelectBuilder(tbl *table) *selectBuilder {
 	}
 }
 
-func (sb *selectBuilder) _select(tbls ...*table) *selectBuilder {
+func (sb *selectBuilder) getXint64(v reflect.Value, unsigned bool) interface{} {
+	switch v.Kind() {
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		if !unsigned {
+			return int(v.Uint())
+		}
+		return v.Uint()
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		if unsigned {
+			return uint64(v.Int())
+		}
+		return v.Int()
+	}
+
+	return 0
+}
+
+func (sb *selectBuilder) _extra(tbl *table, extra ...Cond) *selectBuilder {
+	for _, cond := range extra {
+		if len(cond.vars) == 0 {
+			continue
+		}
+		for i, v := range cond.vars {
+			switch v.(type) {
+			case string:
+				if s := strings.SplitAfterN(v.(string), ".", 2); len(s) > 1 {
+					fmt.Println(s, s[0][0:len(s[0])-1])
+					for _, rels := range tbl.getRelations() {
+						for _, rel := range rels {
+							if rel.f.getFieldName() == s[0][0:len(s[0])-1] {
+								cond.vars[i] = sqlbuilder.Raw(rel.getAlias(true) + "." + SqlFlavor.Quote(s[1]))
+							}
+						}
+					}
+					cond.vars[i] = v
+				}
+			}
+		}
+		switch cond.Method {
+		case Limit:
+			if len(cond.vars) > 0 {
+				switch cond.vars[0].(type) {
+				case int, uint:
+					sb.Limit(cond.vars[0].(int))
+				}
+			}
+		case Where:
+			switch cond.Type {
+			case Null:
+				sb.Where(sb.IsNull(cond.vars[0].(string)))
+			case NotNull:
+				sb.Where(sb.IsNotNull(cond.vars[0].(string)))
+			}
+		}
+	}
+	//fmt.Println(sb)
+	return sb
+
+}
+
+func (sb *selectBuilder) _select(tbls ...interface{}) *selectBuilder {
 	cols := []string{}
 
 	for _, tbl := range tbls {
-		for _, col := range tbl.getColumns() {
-			cols = append(cols, col.getName(true, tbl))
+		switch tbl.(type) {
+		case relation:
+			rtbl := tbl.(relation)
+			for _, col := range rtbl.table.getColumns() {
+				cols = append(cols, col.getName(true, rtbl))
+			}
+		case *table:
+			rtbl := tbl.(*table)
+			for _, col := range rtbl.getColumns() {
+				cols = append(cols, col.getName(true, rtbl))
+			}
 		}
 	}
 
@@ -66,6 +164,7 @@ outerloop:
 					continue outerloop
 				}
 			}
+			fmt.Println(fn, val)
 			where = append(where, sb.Equal(col.getName(true, tbl), val[0]))
 		} else {
 			fmt.Println(err)
@@ -95,9 +194,9 @@ func (sb *selectBuilder) _wherePrimaryOrElse(tbl *table, i MappableInterface) *s
 func (sb *selectBuilder) _join(tbl *table, rel relation, opt ...sqlbuilder.JoinOption) *selectBuilder {
 
 	if len(opt) == 0 {
-		sb.Join(rel.getNameAs(true), sb.Equal(rel.on.getName(true, rel.table), sqlbuilder.Raw(rel.key.getName(true, tbl))))
+		sb.Join(rel.getNameAs(true), sb.Equal(rel.on.getName(true, rel), sqlbuilder.Raw(rel.key.getName(true, tbl))))
 	} else {
-		sb.JoinWithOption(opt[0], rel.getNameAs(true), sb.Equal(rel.on.getName(true, rel.table), sqlbuilder.Raw(rel.key.getName(true, tbl))))
+		sb.JoinWithOption(opt[0], rel.getNameAs(true), sb.Equal(rel.on.getName(true, rel), sqlbuilder.Raw(rel.key.getName(true, tbl))))
 	}
 	return sb
 }

+ 48 - 36
table.go

@@ -5,9 +5,8 @@ import (
 	"math"
 	"reflect"
 	"strconv"
+	"strings"
 	"sync"
-
-	"github.com/go-openapi/inflect"
 )
 
 type table struct {
@@ -24,8 +23,7 @@ func (t *table) Make() MappableInterface {
 	return reflect.New(t.getType()).Interface().(MappableInterface)
 }
 
-func (t *table) CallMethod(i MappableInterface, n string, args ...interface{}) ([]interface{}, error) {
-	var ret []interface{}
+func (t *table) CallMethod(i MappableInterface, n string, args ...interface{}) (ret []interface{}, err error) {
 	if t.getValue(true).MethodByName(n).IsValid() {
 		fn := reflect.ValueOf(i).MethodByName(n)
 		fnt := fn.Type()
@@ -77,22 +75,38 @@ func (t *table) CallMethod(i MappableInterface, n string, args ...interface{}) (
 			in = append(in, argv.Convert(inType))
 		}
 
-		var err error = nil
-		out := fn.Call(in)[0:fnt.NumOut()]
-
-		for _, val := range out {
-			switch val.Kind() {
-			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
-				ret = append(ret, val.Uint())
-			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
-				ret = append(ret, val.Int())
-			case reflect.Float32, reflect.Float64:
-				ret = append(ret, val.Float())
-			case reflect.String:
-				ret = append(ret, val.String())
-			case reflect.Interface:
-				if !val.IsNil() && val.CanInterface() && val.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
-					err = val.Interface().(error)
+		var out []reflect.Value
+
+		defer func() {
+			if r := recover(); r != nil {
+				fmt.Println("Recovered: ", r)
+				err = fmt.Errorf("Recovered: %v", r)
+			}
+		}()
+
+		out = fn.Call(in)[0:fnt.NumOut()]
+
+		if strings.HasPrefix(n, "Get") {
+			for _, val := range out {
+				switch val.Kind() {
+				case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+					ret = append(ret, val.Uint())
+				case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+					ret = append(ret, val.Int())
+				case reflect.Float32, reflect.Float64:
+					ret = append(ret, val.Float())
+				case reflect.String:
+					ret = append(ret, val.String())
+				case reflect.Interface:
+					if !val.IsNil() && val.CanInterface() && val.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
+						err = val.Interface().(error)
+					}
+				case reflect.Struct, reflect.Ptr:
+					if val.CanInterface() && !val.IsNil() {
+						ret = append(ret, val.Interface())
+					} else if val.IsNil() {
+						ret = append(ret, nil)
+					}
 				}
 			}
 		}
@@ -201,7 +215,12 @@ func (t *table) addField(f field, cbCh chan<- mapperCallback) {
 						to:   rtbl,
 						fn: func(tbl *table) {
 							key := tbl.getPrimaryKey()
-							self.rels.addRelation(belongsTo, relation{tbl, f, *key, *c})
+							switch f.MakeType().Kind() {
+							case reflect.Map, reflect.Slice:
+								self.rels.addRelation(belongsTo, relation{tbl, f, *c, *key})
+							default:
+								self.rels.addRelation(belongsTo, relation{tbl, f, *key, *c})
+							}
 						},
 					}
 				} else {
@@ -223,24 +242,12 @@ func (t *table) addField(f field, cbCh chan<- mapperCallback) {
 							// Check for relation on column mane
 							if c := tbl.hasColumn(cn); c != nil {
 
-								// Predict the relations is «hasOne»
 								has := hasOne
-
-								// Try to predict (or simply guess) with pluralization, if «hasMany»
-								if inflect.Pluralize(f.getFieldName()) == f.getFieldName() {
+								switch f.MakeType().Kind() {
+								case reflect.Map, reflect.Slice:
 									has = hasMany
 								}
 
-								// Override with tagging if specified
-								if tag, ok := f.getTag("has"); ok {
-									switch tag {
-									case "many":
-										has = hasMany
-									case "one":
-										has = hasOne
-									}
-								}
-
 								self.rels.addRelation(has, relation{tbl, f, *c, *pkey})
 							}
 						},
@@ -258,7 +265,7 @@ func (t *table) addField(f field, cbCh chan<- mapperCallback) {
 			fallthrough // Support all Int types
 		case reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
 			fallthrough // Support all Float and Complex types
-		case reflect.String, reflect.Bool:
+		case reflect.String, reflect.Bool, reflect.Struct:
 			// Support string and boolean
 
 			// Map column name
@@ -272,6 +279,11 @@ func (t *table) addField(f field, cbCh chan<- mapperCallback) {
 			t.cols = append(t.cols, column{
 				f, dbf,
 			})
+		case reflect.Ptr:
+			f.t = f.t.Elem()
+			f.v = f.v.Elem()
+			t.addField(f, cbCh)
+
 		default:
 			fmt.Println(t.getStructName(), "not supporting", f)
 		}