Browse Source

Add a simple repository, sql building and querying (incl. simple relations)

Joachim M. Giæver 4 years ago
parent
commit
31d56926ae
8 changed files with 265 additions and 47 deletions
  1. 17 4
      column.go
  2. 46 12
      conn/connect.go
  3. 10 0
      field.go
  4. 8 7
      mappable.go
  5. 41 15
      mapper.go
  6. 7 0
      relation.go
  7. 8 4
      strings.go
  8. 128 5
      table.go

+ 17 - 4
column.go

@@ -8,12 +8,25 @@ type column struct {
 
 // GetName returns the name. If relation is given, it
 // will return `tbl`.`col` retpr
-func (c column) GetName(r ...*relation) string {
-	if len(r) > 0 {
-		return "`" + r[0].f.getFieldName() + "`.`" + c.ref + "`"
+func (c column) getName(q bool, rel ...interface{}) string {
+	if len(rel) > 0 {
+		switch rel[0].(type) {
+		case relation:
+			return ((rel[0]).(relation)).getAlias(q) + "." + c.getName(q)
+		case *table:
+			return ((rel[0]).(*table)).getAlias(q) + "." + c.getName(q)
+		}
 	}
 
-	return "`" + c.ref + "`"
+	if q {
+		return SqlFlavor.Quote(c.ref)
+	}
+
+	return c.ref
+}
+
+func (c column) GetAs(q bool, rel ...interface{}) string {
+	return c.getName(q, rel) + " AS "
 }
 
 // columns is the collection of all columns in a table

+ 46 - 12
conn/connect.go

@@ -3,29 +3,63 @@ package conn
 import (
 	"context"
 	"fmt"
+	"net"
+	"strings"
 
-	"git.giaever.org/bnb.hosting/database/config"
 	_ "github.com/go-sql-driver/mysql"
 	"github.com/jmoiron/sqlx"
 )
 
+type OptionQuery map[string]string
+
+func (o OptionQuery) String() string {
+	s := make([]string, 0)
+
+	for k, v := range o {
+		s = append(s, fmt.Sprintf("%s=%s",
+			k, v,
+		))
+	}
+
+	if len(s) == 0 {
+		return ""
+	}
+
+	return fmt.Sprintf("?%s",
+		strings.Join(s, "&"),
+	)
+}
+
+type DBOpts struct {
+	User string
+	Pass string
+	Host net.Addr
+	Db   string
+	Opt  OptionQuery
+}
+
 type DB struct {
 	*sqlx.DB
+	Opts *DBOpts
 }
 
-func Connect(ctx context.Context) (*DB, error) {
+func Connect(opts *DBOpts, ctx context.Context) (*DB, error) {
 	db, err := sqlx.ConnectContext(ctx, "mysql", fmt.Sprintf(
-		"%s:%s@%s(%s:%d)/%s%s",
-		config.USER,
-		config.PASS,
-		config.PROTOCOL,
-		config.HOST,
-		config.PORT,
-		config.DATABASE,
-		config.OPTIONS,
+		"%s:%s@%s(%s)/%s%s",
+		opts.User,
+		opts.Pass,
+		opts.Host.Network(),
+		opts.Host.String(),
+		opts.Db,
+		opts.Opt,
 	))
 
+	if err != nil {
+		return nil, err
+	}
+
 	return &DB{
-		db,
-	}, err
+		DB:   db,
+		Opts: opts,
+	}, nil
 }

+ 10 - 0
field.go

@@ -21,6 +21,16 @@ type field struct {
 	v  reflect.Value
 }
 
+func (f *field) Make() reflect.Value {
+	o := reflect.New(f.MakeType()).Elem()
+	o.Set(reflect.MakeMap(o.Type()))
+	return o
+}
+
+func (f *field) MakeType() reflect.Type {
+	return f.sf.Type
+}
+
 // getFieldName returns the field name within the struct,
 // e.g struct { fieldName fieldType }{}
 func (f *field) getFieldName() string {

+ 8 - 7
mappable.go

@@ -8,28 +8,29 @@ import (
 )
 
 type MappableInterface interface {
-	SetDb(db *conn.DB) MappableInterface
-	GetDb() *conn.DB
+	SetDB(*conn.DB) MappableInterface
+	GetDB() *conn.DB
 	GetTableMapperFn() MapperFn
 	GetColumnMapperFn() MapperFn
 }
 
+type MappedCollection map[uint64]MappableInterface
+
 type MapperFn func(string) string
 
 type Mappable struct {
 	db *conn.DB `db:"omit"`
 }
 
-func (m *Mappable) SetDb(db *conn.DB) MappableInterface {
+func (m *Mappable) SetDB(db *conn.DB) MappableInterface {
 	m.db = db
 	return m
 }
-
-func (m *Mappable) GetDb() *conn.DB {
+func (m *Mappable) GetDB() *conn.DB {
 	return m.db
 }
 
-func (m *Mappable) GetTableMapperFn() MapperFn {
+func (m Mappable) GetTableMapperFn() MapperFn {
 	return func(t string) string {
 		s := []byte{}
 		for i := 0; i < len(t); i++ {
@@ -49,7 +50,7 @@ func (m *Mappable) GetTableMapperFn() MapperFn {
 	}
 }
 
-func (m *Mappable) GetColumnMapperFn() MapperFn {
+func (m Mappable) GetColumnMapperFn() MapperFn {
 	return func(f string) string {
 		s := []byte{}
 		for i := 0; i < len(f); i++ {

+ 41 - 15
mapper.go

@@ -10,7 +10,7 @@ import (
 var Prefix = "db"
 
 type mapper struct {
-	lock *sync.Mutex
+	cond *sync.Cond
 	wg   *sync.WaitGroup
 	tbls tables
 	cbs  *mapperCallbacks
@@ -19,7 +19,7 @@ type mapper struct {
 
 // Global mapper context
 var ctx *mapper = &mapper{
-	lock: &sync.Mutex{},
+	cond: sync.NewCond(&sync.Mutex{}),
 	wg:   &sync.WaitGroup{},
 	tbls: make(tables),
 	cbs: &mapperCallbacks{
@@ -30,14 +30,19 @@ var ctx *mapper = &mapper{
 	init: false,
 }
 
+func Mapper() *mapper {
+	return ctx
+}
+
 // Map a table (struct) that implements MappableInterface
 // Note! Should be mapped in init-function from the that struct
 func Map(tbl MappableInterface) *mapper {
+	fmt.Println("Begin <mapping>")
 
 	// Add to working group
 	ctx.wg.Add(1)
 
-	ctx.lock.Lock()
+	ctx.cond.L.Lock()
 
 	// Kick off some routines on first call
 	if !ctx.init {
@@ -75,7 +80,7 @@ func Map(tbl MappableInterface) *mapper {
 					cb := ctx.cbs.get(l - i - 1)
 
 					// Ensure callback is ran when columns are mapped
-					if t := ctx.getTbl(cb.to); t != nil && t.isMapped() {
+					if t := ctx.getTblByName(cb.to); t != nil && t.isMapped() {
 
 						// Remove callback from slice
 						ctx.cbs.remove(l - i - 1)
@@ -94,7 +99,7 @@ func Map(tbl MappableInterface) *mapper {
 		}()
 	}
 
-	ctx.lock.Unlock()
+	ctx.cond.L.Unlock()
 
 	// Start mapping of table
 	go func() {
@@ -119,13 +124,14 @@ func Map(tbl MappableInterface) *mapper {
 func WaitInit() {
 	ctx.wg.Wait()
 	// Debug print
-	fmt.Println(ctx.tbls)
+	fmt.Println("Done <mapping>")
+	//fmt.Println(ctx.tbls)
 }
 
 // hasTable checks for table
 func (m *mapper) hasTbl(n string) bool {
-	m.lock.Lock()
-	defer m.lock.Unlock()
+	m.cond.L.Lock()
+	defer m.cond.L.Unlock()
 
 	_, ok := m.tbls[n]
 
@@ -133,9 +139,9 @@ func (m *mapper) hasTbl(n string) bool {
 }
 
 // getTbl should only be called controller; Must lock after
-func (m *mapper) getTbl(n string) *table {
-	m.lock.Lock()
-	defer m.lock.Unlock()
+func (m *mapper) getTblByName(n string) *table {
+	m.cond.L.Lock()
+	defer m.cond.L.Unlock()
 
 	if t, ok := m.tbls[n]; ok {
 		return t
@@ -144,20 +150,31 @@ func (m *mapper) getTbl(n string) *table {
 	return nil
 }
 
+func (m *mapper) getTbl(t MappableInterface) *table {
+	m.cond.L.Lock()
+	defer m.cond.L.Unlock()
+
+	for tbl, ok := m.tbls[reflect.TypeOf(t).Elem().Name()]; ; {
+		if tbl == nil || !ok || !tbl.isMapped() {
+			m.cond.Wait()
+		}
+		return tbl
+	}
+}
+
 // addTbl creates a new or returns an existing table; will write lock!
 func (m *mapper) addTbl(t MappableInterface) *table {
+	m.cond.L.Lock()
+	defer m.cond.L.Unlock()
 	rt := reflect.TypeOf(t).Elem()
 
-	m.lock.Lock()
-	defer m.lock.Unlock()
-
 	if t, ok := m.tbls[rt.Name()]; ok {
 		return t.lock()
 	}
 
 	m.tbls[rt.Name()] = &table{
 		rt:   rt,
-		rv:   reflect.ValueOf(t).Elem(),
+		rv:   reflect.ValueOf(t),
 		l:    &sync.RWMutex{},
 		tFn:  t.GetTableMapperFn(),
 		cFn:  t.GetColumnMapperFn(),
@@ -188,6 +205,15 @@ func (m *mapper) mapField(t *table, csf reflect.StructField, cv reflect.Value) f
 	case reflect.Slice:
 		fallthrough
 	case reflect.Map:
+		if x := csf.Type.Elem(); x.Kind() == reflect.Ptr {
+			if m.isPossibleRelation(x.Elem()) {
+				return field{
+					sf: csf,
+					t:  x.Elem(),
+					v:  cv,
+				}
+			}
+		}
 		if m.isPossibleRelation(csf.Type.Elem()) {
 			return field{
 				sf: csf,

+ 7 - 0
relation.go

@@ -22,6 +22,13 @@ type relation struct {
 	key column
 }
 
+func (r relation) getAlias(q bool) string {
+	if q {
+		return SqlFlavor.Quote(r.getAlias(false))
+	}
+	return r.f.getFieldName()
+}
+
 // relations holds all relation on types
 type relations struct {
 	rmap map[relType][]relation

+ 8 - 4
strings.go

@@ -32,8 +32,8 @@ func (t tables) String() string {
 func (t *table) String() string {
 	t.Lock()
 	defer t.Unlock()
-	return fmt.Sprintf("%s (`%s`):%s%s",
-		t.getStructName(), t.getName(),
+	return fmt.Sprintf("%s (%s):%s%s",
+		t.getStructName(), t.getName(false),
 		func() string { // Print columns
 			s := []string{}
 			cols := strings.Split(t.cols.String(), "\n")
@@ -47,6 +47,10 @@ func (t *table) String() string {
 					max = len(col[0])
 				}
 
+				if pk := t.getPrimaryKey(); pk != nil && "* "+pk.getName(true) == col[0] {
+					col[1] = col[1] + ", primary_key"
+				}
+
 				s = append(s, strings.Join(col, ":"))
 			}
 			for i, col := range s {
@@ -80,7 +84,7 @@ func (t *table) String() string {
 }
 
 func (c column) String() string {
-	return c.getFieldName() + ": " + c.GetName() + ", " + c.getKind().String()
+	return c.getFieldName() + ": " + c.getName(true) + ", " + c.getKind().String()
 }
 
 func (cs columns) String() string {
@@ -104,7 +108,7 @@ func (t relType) String() string {
 }
 
 func (r *relation) String() string {
-	return r.f.getFieldType() + " AS `" + r.f.getFieldName() + "` ON " + r.on.GetName(r) + " WITH " + r.key.GetName()
+	return r.getStructName() + " » " + r.getName(true) + " ON " + r.on.getName(true) + " WITH " + r.key.getName(true)
 }
 
 func (r relations) String() string {

+ 128 - 5
table.go

@@ -1,7 +1,10 @@
 package orm
 
 import (
+	"fmt"
+	"math"
 	"reflect"
+	"strconv"
 	"sync"
 
 	"github.com/go-openapi/inflect"
@@ -17,17 +20,111 @@ type table struct {
 	mapped   bool
 }
 
+func (t *table) Make() MappableInterface {
+	return reflect.New(t.getType()).Interface().(MappableInterface)
+}
+
+func (t *table) CallMethod(i MappableInterface, n string, args ...interface{}) ([]interface{}, error) {
+	var ret []interface{}
+	if t.getValue(true).MethodByName(n).IsValid() {
+		fn := reflect.ValueOf(i).MethodByName(n)
+		fnt := fn.Type()
+		in := []reflect.Value{}
+
+		if fnt.IsVariadic() && len(args) < (fnt.NumIn()-1) {
+			return ret, fmt.Errorf("To few arguments to «%s». Got «%d», expected «%d»", n, len(args), fnt.NumIn()-1)
+		} else if !fnt.IsVariadic() && len(args) != fnt.NumIn() {
+			return ret, fmt.Errorf("To few arguments to «%s». Got «%d», expected «%d»", n, len(args), fnt.NumIn()-1)
+		}
+
+		for x := 0; x < len(args); x++ {
+
+			var inType reflect.Type
+			if fnt.IsVariadic() && x >= fnt.NumIn()-1 {
+				inType = fnt.In(fnt.NumIn() - 1).Elem()
+			} else {
+				inType = fnt.In(x)
+			}
+
+			argv := reflect.ValueOf(args[x])
+
+			if !argv.IsValid() || !argv.Type().ConvertibleTo(inType) {
+
+				switch inType.Kind() {
+				case
+					reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+					reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+					var val uint64 = 0
+					argx := args[x].([]uint8)
+					for xi := len(argx) - 1; xi >= 0; xi-- {
+						val += uint64(math.Pow(10, float64(len(argx)-xi-1))) * uint64(argx[xi]-'0')
+					}
+					args[x] = val
+					return t.CallMethod(i, n, args...)
+				case reflect.Float32, reflect.Float64:
+					if val, err := strconv.ParseFloat(string(args[x].([]byte)), 64); err == nil {
+						args[x] = val
+						return t.CallMethod(i, n, args...)
+					}
+				case reflect.Complex64, reflect.Complex128:
+					// Not implemented
+					return ret, fmt.Errorf("Complex not implemented")
+				}
+
+				return ret, fmt.Errorf("Invalid argument to «%s». Got %s, expected %s", n, argv.String(), inType.String())
+			}
+
+			in = append(in, argv.Convert(inType))
+		}
+
+		var err error = nil
+		out := fn.Call(in)[0:fnt.NumOut()]
+
+		for _, val := range out {
+			switch val.Kind() {
+			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+				ret = append(ret, val.Uint())
+			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+				ret = append(ret, val.Int())
+			case reflect.Float32, reflect.Float64:
+				ret = append(ret, val.Float())
+			case reflect.String:
+				ret = append(ret, val.String())
+			case reflect.Interface:
+				if !val.IsNil() && val.CanInterface() && val.Type().Implements(reflect.TypeOf((*error)(nil)).Elem()) {
+					err = val.Interface().(error)
+				}
+			}
+		}
+
+		return ret, err
+
+	}
+
+	return ret, fmt.Errorf("Invalid method «%s» on «%s»", n, t.getType())
+}
+
 // getPrimaryKey tries to find primary key
 func (t *table) getPrimaryKey() *column {
-	var pkey *column
+	var pkey *column = nil
+
 	if c := t.hasTaggedColumn("primary"); c != nil {
 		pkey = c
 	} else if c := t.hasColumn("Id"); c != nil {
 		pkey = c
 	}
+
 	return pkey
 }
 
+func (t *table) getRelations() map[relType][]relation {
+	return t.rels.rmap
+}
+
+func (t *table) getColumns() columns {
+	return t.cols
+}
+
 // hasTaggetColumn checks for a collumn tagget as
 func (t *table) hasTaggedColumn(ct string) *column {
 	for _, col := range t.cols {
@@ -66,7 +163,7 @@ func (t *table) addField(f field, cbCh chan<- mapperCallback) {
 
 			fn: func(self *table) {
 				// Get Primary Key
-				pkey := t.getPrimaryKey()
+				pkey := self.getPrimaryKey()
 
 				// Or at least try to
 				if pkey == nil {
@@ -89,6 +186,12 @@ func (t *table) addField(f field, cbCh chan<- mapperCallback) {
 					cn = tag
 				}
 
+				/*
+					if self.getStructName() == "PlatformCalendar" {
+						fmt.Println(self.getStructName(), f.getFieldType(), cn)
+					}
+				*/
+
 				// Check if it contains reference itself
 				if c := self.hasColumn(cn); c != nil {
 					// Make a call to load related table into scope;
@@ -169,6 +272,8 @@ func (t *table) addField(f field, cbCh chan<- mapperCallback) {
 			t.cols = append(t.cols, column{
 				f, dbf,
 			})
+		default:
+			fmt.Println(t.getStructName(), "not supporting", f)
 		}
 	}
 }
@@ -179,8 +284,11 @@ func (t *table) getType() reflect.Type {
 }
 
 // getValue returns the reflect.Value of the «table»
-func (t *table) getValue() reflect.Value {
-	return t.rv
+func (t *table) getValue(ptr ...bool) reflect.Value {
+	if len(ptr) > 0 && ptr[0] {
+		return t.rv
+	}
+	return t.rv.Elem()
 }
 
 // isMapped returns true when columns is mapped
@@ -218,9 +326,24 @@ func (t *table) getStructName() string {
 
 // getName returns the mapped table name
 // as identified in the database
-func (t *table) getName() string {
+func (t *table) getName(q bool) string {
+	if q {
+		return SqlFlavor.Quote(t.getName(false))
+	}
 	return t.tFn(t.getType().Name())
 }
 
+func (t *table) getNameAs(q bool) string {
+	return t.getName(q) + " AS " + t.getAlias(q)
+}
+
+func (t *table) getAlias(q bool) string {
+	if q {
+		return SqlFlavor.Quote(t.getAlias(false))
+	}
+
+	return t.getStructName()
+}
+
 // tables is simply a collection of tables
 type tables map[string]*table