repository.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. package orm
  2. import (
  3. "fmt"
  4. "reflect"
  5. "strings"
  6. "git.giaever.org/bnb.hosting/orm/conn"
  7. "github.com/huandu/go-sqlbuilder"
  8. "github.com/jmoiron/sqlx"
  9. )
  10. type repository struct {
  11. dbc map[string]MappableInterface
  12. }
  13. var r *repository = nil
  14. func Repository() *repository {
  15. if r == nil {
  16. r = &repository{
  17. make(map[string]MappableInterface),
  18. }
  19. }
  20. return r
  21. }
  22. func (r *repository) FetchFirst(db *conn.DB, i MappableInterface) (MappableInterface, error) {
  23. res, err := r.Fetch(db, i, 1)
  24. if len(res) == 1 {
  25. return res[0], err
  26. }
  27. return nil, err
  28. }
  29. func (r *repository) Fetch(db *conn.DB, i MappableInterface, offlim ...int) ([]MappableInterface, error) {
  30. tbl := ctx.getTbl(i)
  31. b := getSelectBuilder(tbl)._select(tbl)._where(tbl, i)
  32. if len(offlim) == 1 {
  33. b.Limit(offlim[0])
  34. } else if len(offlim) > 1 {
  35. b.Limit(offlim[1])
  36. b.Offset(offlim[0])
  37. }
  38. return r.query(db, tbl, b)
  39. }
  40. func (r *repository) FetchRelated(i MappableInterface, n string) (interface{}, error) {
  41. tbl := ctx.getTbl(i)
  42. for rtype, rels := range tbl.getRelations() {
  43. for _, rel := range rels {
  44. if rel.f.getFieldName() == n {
  45. switch rtype {
  46. case hasMany, belongsTo:
  47. b := getSelectBuilder(tbl)
  48. b._select(rel.table)
  49. b._join(tbl, rel)
  50. b._wherePrimaryOrElse(tbl, i)
  51. switch rel.f.MakeType().Kind() {
  52. case reflect.Map, reflect.Slice:
  53. // No limit?
  54. default:
  55. b.Limit(1)
  56. }
  57. res, err := r.query(i.GetDB(), rel.table, b)
  58. if err != nil {
  59. return nil, err
  60. }
  61. return r.mapScanResultToInterface(i, tbl, rel, res), nil
  62. case hasOne:
  63. b := getSelectBuilder(tbl)
  64. b._select(rel.table)
  65. b._join(tbl, rel, sqlbuilder.RightJoin)
  66. b._wherePrimaryOrElse(tbl, i)
  67. b.Limit(1)
  68. res, err := r.query(i.GetDB(), rel.table, b)
  69. if err != nil || len(res) == 0 {
  70. return nil, err
  71. }
  72. return res[0], nil
  73. }
  74. }
  75. }
  76. }
  77. return nil, nil
  78. }
  79. func (r *repository) mapScanResultToInterface(i MappableInterface, tbl *table, rel relation, res []MappableInterface) interface{} {
  80. m := rel.f.Make()
  81. switch m.Kind() {
  82. case reflect.Map:
  83. for midx, rel := range res {
  84. if val, err := tbl.CallMethod(i, "GetId"); err != nil {
  85. switch val[0].(type) {
  86. case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
  87. m.SetMapIndex(reflect.ValueOf(val[0].(uint64)), reflect.ValueOf(rel))
  88. }
  89. } else {
  90. m.SetMapIndex(reflect.ValueOf(uint64(midx)), reflect.ValueOf(rel))
  91. }
  92. }
  93. case reflect.Slice:
  94. for _, rel := range res {
  95. m = reflect.Append(m, reflect.ValueOf(rel))
  96. }
  97. default:
  98. if len(res) != 0 {
  99. m.Set(reflect.ValueOf(res[0]))
  100. }
  101. }
  102. return m.Interface()
  103. }
  104. func (r *repository) mapScan(i MappableInterface, rows *sqlx.Rows, tbl *table) (MappableInterface, error) {
  105. res := make(map[string]interface{})
  106. if err := rows.MapScan(res); err != nil {
  107. fmt.Println("Error:", err)
  108. return nil, err
  109. }
  110. for n, s := range res {
  111. if fn := tbl.hasColumn(n).getFieldName(); len(fn) != 0 {
  112. if _, err := tbl.CallMethod(i, "Set"+strings.Title(fn), s); err != nil {
  113. fmt.Println("ERROR", err)
  114. //return nil, err
  115. }
  116. } else {
  117. fmt.Println("MISSING FIELD NAME")
  118. }
  119. }
  120. if col := tbl.getPrimaryKey(); col != nil {
  121. if v, err := tbl.CallMethod(i, "Get"+strings.Title(col.getFieldName())); err == nil {
  122. idx := fmt.Sprintf("%s_%#v", tbl.getStructName(), v[0])
  123. if ni, ok := r.dbc[idx]; ok {
  124. return ni, nil
  125. } else {
  126. r.dbc[idx] = i
  127. }
  128. }
  129. }
  130. return i, nil
  131. }
  132. func (r *repository) query(db *conn.DB, tbl *table, sql sqlbuilder.Builder) ([]MappableInterface, error) {
  133. var ret []MappableInterface
  134. q, args := sql.Build()
  135. rows, err := db.Queryx(q, args...)
  136. for err != nil {
  137. return nil, err
  138. }
  139. //defer rows.Close()
  140. for rows.Next() {
  141. i, err := r.mapScan(tbl.Make(), rows, tbl)
  142. if err == nil {
  143. i.SetDB(db)
  144. ret = append(ret, i)
  145. } else {
  146. fmt.Println(err)
  147. return nil, err
  148. }
  149. }
  150. return ret, nil
  151. }