-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsql.go
132 lines (108 loc) · 3.24 KB
/
sql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package sql_util
import (
"database/sql"
"fmt"
"reflect"
"strings"
"sync"
)
// TAG_NAME is the struct-tag key that getStructFields reads to find a field's
// database column name; TAG_SEPARATOR splits that tag's value, and only the
// first segment is used as the column name.
const TAG_NAME, TAG_SEPARATOR = "db", ","
// GoSqlUtil is the package-wide SqlUtil instance. getStructFields stores each
// destination type's column layout in its cache, which starts out empty but
// non-nil so entries can be added without further initialization.
var GoSqlUtil = &SqlUtil{cachedStructMaps: map[reflect.Type]*structInfo{}}
// SqlUtil scans *sql.Rows into caller-supplied struct pointers (see Scan),
// keeping a cache of each destination type's reflected column layout.
type SqlUtil struct {
	// cachedStructMaps maps a destination pointer type to its column layout.
	cachedStructMaps map[reflect.Type]*structInfo
	// structsCacheMutex is locked around updates to cachedStructMaps.
	structsCacheMutex sync.Mutex
}
// structField records how one exported struct field maps to a result column.
type structField struct {
	DBColumnName, FieldName string // column name (from the db tag, else the Go name) and Go field name
	Index                   int    // field position within the struct, for reflect.Value.Field
}
// structInfo is the cached layout of one destination struct type: its
// indexable fields keyed by result-column name.
type structInfo struct{ structFields map[string]*structField }
// Scan copies the current row of rows into fields of the struct pointers in
// structs, matching each result column to a field by its db tag or Go field
// name. It is a thin method wrapper around the package-level scan and
// returns whatever error scan reports.
func (*SqlUtil) Scan(rows *sql.Rows, structs ...interface{}) error {
	err := scan(rows, structs...)
	return err
}
// scan copies the current row of rows into fields of the struct pointers in
// structs. Each result column is matched to a field via findScanTargets.
//
// As with (*sql.Rows).Scan, the caller is expected to have advanced rows with
// Next first. scan returns:
//   - any error already recorded on rows (rows.Err);
//   - an error from reading the column names;
//   - an error from resolving scan targets;
//   - the error from rows.Scan itself, e.g. a type conversion failure.
func scan(rows *sql.Rows, structs ...interface{}) error {
	if err := rows.Err(); err != nil {
		return err
	}
	columns, err := rows.Columns()
	if err != nil {
		return err
	}
	targets, err := findScanTargets(columns, structs...)
	if err != nil {
		return err
	}
	// Propagate the Scan result: previously it was discarded, so conversion
	// failures and target/column mismatches were reported as success.
	return rows.Scan(targets...)
}
func findScanTargets(columns []string, structs ...interface{}) ([]interface{}, error) {
var targets []interface{}
for _, columnName := range columns {
for _, dst := range structs {
data, err := getStructFields(reflect.TypeOf(dst))
structVal := reflect.ValueOf(dst).Elem()
if err != nil {
return nil, err
}
if field, present := data.(*structInfo).structFields[columnName]; present {
fieldAddr := structVal.Field(field.Index).Addr().Interface()
targets = append(targets, fieldAddr)
break
}
}
}
return targets, nil
}
// getStructFields returns the column layout for destinationType as a
// *structInfo boxed in interface{}. The layout is built on first use and
// cached in GoSqlUtil.
//
// destinationType must be a pointer to a struct. Anything else, including a
// nil Type, yields an error. Within the struct:
//   - unexported fields are skipped;
//   - fields whose type is a pointer to a struct are skipped (see TODO);
//   - a field whose db tag is "-" is skipped;
//   - otherwise the column name is the first comma-separated segment of the
//     db tag, or the Go field name when that segment is empty.
//
// Only the struct's own fields are indexed; an embedded non-pointer struct is
// treated as one column named after the field. When two fields resolve to the
// same column name, the later field wins.
func getStructFields(destinationType reflect.Type) (interface{}, error) {
	// Hold the lock across the whole lookup-or-build. cachedStructMaps is
	// shared through GoSqlUtil, and a map read racing with a map write is a
	// data race; the original only locked the write.
	GoSqlUtil.structsCacheMutex.Lock()
	defer GoSqlUtil.structsCacheMutex.Unlock()

	if data, present := GoSqlUtil.cachedStructMaps[destinationType]; present {
		return data, nil
	}
	// A nil Type comes from a nil interface destination; calling Kind on it
	// would panic.
	if destinationType == nil || destinationType.Kind() != reflect.Ptr {
		return nil, fmt.Errorf("SQL Util called with non-pointer destination %v", destinationType)
	}
	structType := destinationType.Elem()
	if structType.Kind() != reflect.Struct {
		return nil, fmt.Errorf("SQL Util called with pointer to non-struct %v", destinationType)
	}

	numStructFields := structType.NumField()
	data := &structInfo{structFields: make(map[string]*structField, numStructFields)}
	for i := 0; i < numStructFields; i++ {
		f := structType.Field(i)
		if f.PkgPath != "" {
			// Unexported: reflect refuses Interface() on such fields, so they
			// cannot be scan targets.
			continue
		}
		if f.Type.Kind() == reflect.Ptr && f.Type.Elem().Kind() == reflect.Struct {
			// TODO - Add functionality to handle embedded structs
			continue
		}
		// strings.Split always returns at least one element, so tags[0] exists.
		tags := strings.Split(f.Tag.Get(TAG_NAME), TAG_SEPARATOR)
		columnName := tags[0]
		if columnName == "-" {
			continue
		}
		if columnName == "" {
			columnName = f.Name
		}
		data.structFields[columnName] = &structField{
			DBColumnName: columnName,
			FieldName:    f.Name,
			Index:        i,
		}
	}
	GoSqlUtil.cachedStructMaps[destinationType] = data
	return data, nil
}