Skip to content

Commit c7c0d6d

Browse files
committed
feat: add null type
1 parent 1399d3c commit c7c0d6d

File tree

4 files changed

+5981
-0
lines changed

4 files changed

+5981
-0
lines changed

convert.go

Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
// Copyright 2011 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
// Type conversions for Scan.
6+
7+
package datatypes
8+
9+
import (
10+
"bytes"
11+
"database/sql"
12+
"database/sql/driver"
13+
"errors"
14+
"fmt"
15+
"reflect"
16+
"strconv"
17+
"time"
18+
)
19+
20+
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
21+
22+
// convertAssign is the same as convertAssignRows, but without the optional
23+
// rows argument.
24+
func convertAssign(dest, src any) error {
25+
return convertAssignRows(dest, src, nil)
26+
}
27+
28+
// convertAssignRows copies to dest the value in src, converting it if possible.
29+
// An error is returned if the copy would result in loss of information.
30+
// dest should be a pointer type. If rows is passed in, the rows will
31+
// be used as the parent for any cursor values converted from a
32+
// driver.Rows to a *Rows.
33+
func convertAssignRows(dest, src any, rows *sql.Rows) error {
34+
// Common cases, without reflect.
35+
switch s := src.(type) {
36+
case string:
37+
switch d := dest.(type) {
38+
case *string:
39+
if d == nil {
40+
return errNilPtr
41+
}
42+
*d = s
43+
return nil
44+
case *[]byte:
45+
if d == nil {
46+
return errNilPtr
47+
}
48+
*d = []byte(s)
49+
return nil
50+
case *RawBytes:
51+
if d == nil {
52+
return errNilPtr
53+
}
54+
*d = append((*d)[:0], s...)
55+
return nil
56+
}
57+
case []byte:
58+
switch d := dest.(type) {
59+
case *string:
60+
if d == nil {
61+
return errNilPtr
62+
}
63+
*d = string(s)
64+
return nil
65+
case *any:
66+
if d == nil {
67+
return errNilPtr
68+
}
69+
*d = bytes.Clone(s)
70+
return nil
71+
case *[]byte:
72+
if d == nil {
73+
return errNilPtr
74+
}
75+
*d = bytes.Clone(s)
76+
return nil
77+
case *RawBytes:
78+
if d == nil {
79+
return errNilPtr
80+
}
81+
*d = s
82+
return nil
83+
}
84+
case time.Time:
85+
switch d := dest.(type) {
86+
case *time.Time:
87+
*d = s
88+
return nil
89+
case *string:
90+
*d = s.Format(time.RFC3339Nano)
91+
return nil
92+
case *[]byte:
93+
if d == nil {
94+
return errNilPtr
95+
}
96+
*d = []byte(s.Format(time.RFC3339Nano))
97+
return nil
98+
case *RawBytes:
99+
if d == nil {
100+
return errNilPtr
101+
}
102+
*d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
103+
return nil
104+
}
105+
case decimalDecompose:
106+
switch d := dest.(type) {
107+
case decimalCompose:
108+
return d.Compose(s.Decompose(nil))
109+
}
110+
case nil:
111+
switch d := dest.(type) {
112+
case *any:
113+
if d == nil {
114+
return errNilPtr
115+
}
116+
*d = nil
117+
return nil
118+
case *[]byte:
119+
if d == nil {
120+
return errNilPtr
121+
}
122+
*d = nil
123+
return nil
124+
case *RawBytes:
125+
if d == nil {
126+
return errNilPtr
127+
}
128+
*d = nil
129+
return nil
130+
}
131+
// The driver is returning a cursor the client may iterate over.
132+
case driver.Rows:
133+
switch d := dest.(type) {
134+
case *sql.Rows:
135+
if d == nil {
136+
return errNilPtr
137+
}
138+
if rows == nil {
139+
return errors.New("invalid context to convert cursor rows, missing parent *Rows")
140+
}
141+
rows.closemu.Lock()
142+
*d = sql.Rows{
143+
dc: rows.dc,
144+
releaseConn: func(error) {},
145+
rowsi: s,
146+
}
147+
// Chain the cancel function.
148+
parentCancel := rows.cancel
149+
rows.cancel = func() {
150+
// When Rows.cancel is called, the closemu will be locked as well.
151+
// So we can access rs.lasterr.
152+
d.close(rows.lasterr)
153+
if parentCancel != nil {
154+
parentCancel()
155+
}
156+
}
157+
rows.closemu.Unlock()
158+
return nil
159+
}
160+
}
161+
162+
var sv reflect.Value
163+
164+
switch d := dest.(type) {
165+
case *string:
166+
sv = reflect.ValueOf(src)
167+
switch sv.Kind() {
168+
case reflect.Bool,
169+
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
170+
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
171+
reflect.Float32, reflect.Float64:
172+
*d = asString(src)
173+
return nil
174+
}
175+
case *[]byte:
176+
sv = reflect.ValueOf(src)
177+
if b, ok := asBytes(nil, sv); ok {
178+
*d = b
179+
return nil
180+
}
181+
case *RawBytes:
182+
sv = reflect.ValueOf(src)
183+
if b, ok := asBytes([]byte(*d)[:0], sv); ok {
184+
*d = RawBytes(b)
185+
return nil
186+
}
187+
case *bool:
188+
bv, err := driver.Bool.ConvertValue(src)
189+
if err == nil {
190+
*d = bv.(bool)
191+
}
192+
return err
193+
case *any:
194+
*d = src
195+
return nil
196+
}
197+
198+
if scanner, ok := dest.(sql.Scanner); ok {
199+
return scanner.Scan(src)
200+
}
201+
202+
dpv := reflect.ValueOf(dest)
203+
if dpv.Kind() != reflect.Pointer {
204+
return errors.New("destination not a pointer")
205+
}
206+
if dpv.IsNil() {
207+
return errNilPtr
208+
}
209+
210+
if !sv.IsValid() {
211+
sv = reflect.ValueOf(src)
212+
}
213+
214+
dv := reflect.Indirect(dpv)
215+
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
216+
switch b := src.(type) {
217+
case []byte:
218+
dv.Set(reflect.ValueOf(bytes.Clone(b)))
219+
default:
220+
dv.Set(sv)
221+
}
222+
return nil
223+
}
224+
225+
if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
226+
dv.Set(sv.Convert(dv.Type()))
227+
return nil
228+
}
229+
230+
// The following conversions use a string value as an intermediate representation
231+
// to convert between various numeric types.
232+
//
233+
// This also allows scanning into user defined types such as "type Int int64".
234+
// For symmetry, also check for string destination types.
235+
switch dv.Kind() {
236+
case reflect.Pointer:
237+
if src == nil {
238+
dv.SetZero()
239+
return nil
240+
}
241+
dv.Set(reflect.New(dv.Type().Elem()))
242+
return convertAssignRows(dv.Interface(), src, rows)
243+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
244+
if src == nil {
245+
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
246+
}
247+
s := asString(src)
248+
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
249+
if err != nil {
250+
err = strconvErr(err)
251+
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
252+
}
253+
dv.SetInt(i64)
254+
return nil
255+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
256+
if src == nil {
257+
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
258+
}
259+
s := asString(src)
260+
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
261+
if err != nil {
262+
err = strconvErr(err)
263+
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
264+
}
265+
dv.SetUint(u64)
266+
return nil
267+
case reflect.Float32, reflect.Float64:
268+
if src == nil {
269+
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
270+
}
271+
s := asString(src)
272+
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
273+
if err != nil {
274+
err = strconvErr(err)
275+
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
276+
}
277+
dv.SetFloat(f64)
278+
return nil
279+
case reflect.String:
280+
if src == nil {
281+
return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
282+
}
283+
switch v := src.(type) {
284+
case string:
285+
dv.SetString(v)
286+
return nil
287+
case []byte:
288+
dv.SetString(string(v))
289+
return nil
290+
}
291+
}
292+
293+
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
294+
}
295+
296+
func strconvErr(err error) error {
297+
if ne, ok := err.(*strconv.NumError); ok {
298+
return ne.Err
299+
}
300+
return err
301+
}
302+
303+
func asString(src any) string {
304+
switch v := src.(type) {
305+
case string:
306+
return v
307+
case []byte:
308+
return string(v)
309+
}
310+
rv := reflect.ValueOf(src)
311+
switch rv.Kind() {
312+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
313+
return strconv.FormatInt(rv.Int(), 10)
314+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
315+
return strconv.FormatUint(rv.Uint(), 10)
316+
case reflect.Float64:
317+
return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
318+
case reflect.Float32:
319+
return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
320+
case reflect.Bool:
321+
return strconv.FormatBool(rv.Bool())
322+
}
323+
return fmt.Sprintf("%v", src)
324+
}
325+
326+
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
327+
switch rv.Kind() {
328+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
329+
return strconv.AppendInt(buf, rv.Int(), 10), true
330+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
331+
return strconv.AppendUint(buf, rv.Uint(), 10), true
332+
case reflect.Float32:
333+
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
334+
case reflect.Float64:
335+
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
336+
case reflect.Bool:
337+
return strconv.AppendBool(buf, rv.Bool()), true
338+
case reflect.String:
339+
s := rv.String()
340+
return append(buf, s...), true
341+
}
342+
return
343+
}
344+
345+
type decimalDecompose interface {
346+
// Decompose returns the internal decimal state in parts.
347+
// If the provided buf has sufficient capacity, buf may be returned as the coefficient with
348+
// the value set and length set as appropriate.
349+
Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
350+
}
351+
352+
type decimalCompose interface {
353+
// Compose sets the internal decimal value from parts. If the value cannot be
354+
// represented then an error should be returned.
355+
Compose(form byte, negative bool, coefficient []byte, exponent int32) error
356+
}

0 commit comments

Comments
 (0)