iter.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package sqlite
  2. import (
  3. "context"
  4. "database/sql"
  5. "reflect"
  6. "github.com/wecisecode/util/merrs"
  7. )
  // Row holds one result row decoded into a map from column name to value.
  type Row struct {
  	data map[string]interface{}
  }

  // Data returns the row's column-name-to-value map. The map itself is
  // returned (not a copy), so changes made through it are visible via the Row.
  func (r *Row) Data() map[string]interface{} {
  	m := r.data
  	return m
  }
  // Iter is a forward-only cursor over a query result set. It bundles the
  // *sql.Rows being read with the column metadata used to decode each row
  // into a map keyed by column name.
  type Iter struct {
  	ctx  context.Context   // optional; when non-nil, NextRow checks it before reading a row
  	rows *sql.Rows         // the result set being iterated
  	cols []*sql.ColumnType // column metadata, one entry per scanned value
  }
  20. // 获取下一数据
  21. func (iter *Iter) NextRow() (row *Row, err error) {
  22. if iter.ctx != nil {
  23. err = iter.ctx.Err()
  24. if err != nil {
  25. return
  26. }
  27. }
  28. if !iter.rows.Next() {
  29. iter.Close()
  30. return
  31. }
  32. ftypes := make([]string, len(iter.cols))
  33. values := make([]interface{}, len(iter.cols))
  34. for i := range values {
  35. coltype := iter.cols[i]
  36. scantype := coltype.ScanType()
  37. ftypes[i] = coltype.DatabaseTypeName()
  38. if scantype == nil {
  39. bs := []byte{}
  40. values[i] = &bs
  41. } else {
  42. v := reflect.New(scantype).Interface()
  43. values[i] = &v
  44. }
  45. }
  46. if err = iter.rows.Scan(values...); err != nil {
  47. err = merrs.NormalError.NewCause(err)
  48. iter.Close()
  49. return
  50. }
  51. m := make(map[string]interface{})
  52. for i, v := range values {
  53. fieldname := iter.cols[i].Name()
  54. ftype := ftypes[i]
  55. if iter.cols[i].ScanType() == nil {
  56. v = nil
  57. } else {
  58. v = reflect.ValueOf(v).Elem().Interface()
  59. v = SQLValueDecode(ftype, v)
  60. }
  61. m[fieldname] = v
  62. }
  63. row = &Row{m}
  64. return
  65. }
  66. // 关闭迭代器
  67. func (iter *Iter) Close() (err error) {
  68. return iter.rows.Close()
  69. }
  70. // 一次获取多行数据
  71. func (iter *Iter) NextRows(n int) (rows []*Row, err error) {
  72. rows = []*Row{}
  73. for {
  74. rd, e := iter.NextRow()
  75. if e != nil {
  76. err = merrs.NewError(e)
  77. return
  78. }
  79. if rd == nil || n > 0 && len(rows) >= n {
  80. return
  81. }
  82. rows = append(rows, rd)
  83. }
  84. }
  85. // 一次获取多行map形式的数据
  86. func (iter *Iter) NextMaps(n int) (rds []map[string]interface{}, err error) {
  87. rds = []map[string]interface{}{}
  88. for {
  89. rd, e := iter.NextRow()
  90. if e != nil {
  91. err = merrs.NormalError.NewCause(e)
  92. return
  93. }
  94. if rd == nil || n > 0 && len(rds) >= n {
  95. return
  96. }
  97. rds = append(rds, rd.Data())
  98. }
  99. }
  100. // 一次获取所有数据
  101. func (iter *Iter) AllRows() (rows []*Row, err error) {
  102. defer iter.Close()
  103. return iter.NextRows(-1)
  104. }
  105. // 一次获取所有map形式的数据
  106. func (iter *Iter) AllMaps() (rds []map[string]interface{}, err error) {
  107. defer iter.Close()
  108. return iter.NextMaps(-1)
  109. }