1
2
3
4
5
6
7 package sql
8
9 import (
10 "database/sql/driver"
11 "errors"
12 "fmt"
13 "reflect"
14 "strconv"
15 "time"
16 "unicode"
17 "unicode/utf8"
18 )
19
// errNilPtr is returned by convertAssign when the destination pointer it is
// asked to write through is nil.
var errNilPtr = errors.New("destination pointer is nil")
21
// describeNamedValue returns a short human-readable label for an argument,
// used in error messages: `with name "x"` for named arguments and `$N`
// (the 1-based ordinal) for positional ones.
func describeNamedValue(nv *driver.NamedValue) string {
	if nv.Name != "" {
		return fmt.Sprintf("with name %q", nv.Name)
	}
	return fmt.Sprintf("$%d", nv.Ordinal)
}
28
// validateNamedValueName accepts an empty name (a positional argument) or a
// name whose first rune is a Unicode letter; anything else is rejected.
func validateNamedValueName(name string) error {
	if name == "" {
		return nil
	}
	if first, _ := utf8.DecodeRuneInString(name); !unicode.IsLetter(first) {
		return fmt.Errorf("name %q does not begin with a letter", name)
	}
	return nil
}
39
40
41
42
43 type ccChecker struct {
44 cci driver.ColumnConverter
45 want int
46 }
47
48 func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
49 if c.cci == nil {
50 return driver.ErrSkip
51 }
52
53
54
55 index := nv.Ordinal - 1
56 if c.want <= index {
57 return nil
58 }
59
60
61
62
63 if vr, ok := nv.Value.(driver.Valuer); ok {
64 sv, err := callValuerValue(vr)
65 if err != nil {
66 return err
67 }
68 if !driver.IsValue(sv) {
69 return fmt.Errorf("non-subset type %T returned from Value", sv)
70 }
71 nv.Value = sv
72 }
73
74
75
76
77
78
79
80
81 var err error
82 arg := nv.Value
83 nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
84 if err != nil {
85 return err
86 }
87 if !driver.IsValue(nv.Value) {
88 return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
89 }
90 return nil
91 }
92
93
94
95
// defaultCheckNamedValue is the last-resort argument checker: it runs
// nv.Value through driver.DefaultParameterConverter. nv.Value is replaced by
// the converter's result even when an error is returned.
func defaultCheckNamedValue(nv *driver.NamedValue) error {
	converted, convErr := driver.DefaultParameterConverter.ConvertValue(nv.Value)
	nv.Value = converted
	return convErr
}
100
101
102
103
104
105 func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
106 nvargs := make([]driver.NamedValue, len(args))
107
108
109
110
111 want := -1
112
113 var si driver.Stmt
114 var cc ccChecker
115 if ds != nil {
116 si = ds.si
117 want = ds.si.NumInput()
118 cc.want = want
119 }
120
121
122
123
124
125 nvc, ok := si.(driver.NamedValueChecker)
126 if !ok {
127 nvc, ok = ci.(driver.NamedValueChecker)
128 }
129 cci, ok := si.(driver.ColumnConverter)
130 if ok {
131 cc.cci = cci
132 }
133
134
135
136
137
138
139 var err error
140 var n int
141 for _, arg := range args {
142 nv := &nvargs[n]
143 if np, ok := arg.(NamedArg); ok {
144 if err = validateNamedValueName(np.Name); err != nil {
145 return nil, err
146 }
147 arg = np.Value
148 nv.Name = np.Name
149 }
150 nv.Ordinal = n + 1
151 nv.Value = arg
152
153
154
155
156
157
158
159
160
161
162
163
164 checker := defaultCheckNamedValue
165 nextCC := false
166 switch {
167 case nvc != nil:
168 nextCC = cci != nil
169 checker = nvc.CheckNamedValue
170 case cci != nil:
171 checker = cc.CheckNamedValue
172 }
173
174 nextCheck:
175 err = checker(nv)
176 switch err {
177 case nil:
178 n++
179 continue
180 case driver.ErrRemoveArgument:
181 nvargs = nvargs[:len(nvargs)-1]
182 continue
183 case driver.ErrSkip:
184 if nextCC {
185 nextCC = false
186 checker = cc.CheckNamedValue
187 } else {
188 checker = defaultCheckNamedValue
189 }
190 goto nextCheck
191 default:
192 return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
193 }
194 }
195
196
197
198 if want != -1 && len(nvargs) != want {
199 return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
200 }
201
202 return nvargs, nil
203
204 }
205
206
207
208
209 func convertAssign(dest, src interface{}) error {
210
211 switch s := src.(type) {
212 case string:
213 switch d := dest.(type) {
214 case *string:
215 if d == nil {
216 return errNilPtr
217 }
218 *d = s
219 return nil
220 case *[]byte:
221 if d == nil {
222 return errNilPtr
223 }
224 *d = []byte(s)
225 return nil
226 case *RawBytes:
227 if d == nil {
228 return errNilPtr
229 }
230 *d = append((*d)[:0], s...)
231 return nil
232 }
233 case []byte:
234 switch d := dest.(type) {
235 case *string:
236 if d == nil {
237 return errNilPtr
238 }
239 *d = string(s)
240 return nil
241 case *interface{}:
242 if d == nil {
243 return errNilPtr
244 }
245 *d = cloneBytes(s)
246 return nil
247 case *[]byte:
248 if d == nil {
249 return errNilPtr
250 }
251 *d = cloneBytes(s)
252 return nil
253 case *RawBytes:
254 if d == nil {
255 return errNilPtr
256 }
257 *d = s
258 return nil
259 }
260 case time.Time:
261 switch d := dest.(type) {
262 case *time.Time:
263 *d = s
264 return nil
265 case *string:
266 *d = s.Format(time.RFC3339Nano)
267 return nil
268 case *[]byte:
269 if d == nil {
270 return errNilPtr
271 }
272 *d = []byte(s.Format(time.RFC3339Nano))
273 return nil
274 case *RawBytes:
275 if d == nil {
276 return errNilPtr
277 }
278 *d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
279 return nil
280 }
281 case nil:
282 switch d := dest.(type) {
283 case *interface{}:
284 if d == nil {
285 return errNilPtr
286 }
287 *d = nil
288 return nil
289 case *[]byte:
290 if d == nil {
291 return errNilPtr
292 }
293 *d = nil
294 return nil
295 case *RawBytes:
296 if d == nil {
297 return errNilPtr
298 }
299 *d = nil
300 return nil
301 }
302 }
303
304 var sv reflect.Value
305
306 switch d := dest.(type) {
307 case *string:
308 sv = reflect.ValueOf(src)
309 switch sv.Kind() {
310 case reflect.Bool,
311 reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
312 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
313 reflect.Float32, reflect.Float64:
314 *d = asString(src)
315 return nil
316 }
317 case *[]byte:
318 sv = reflect.ValueOf(src)
319 if b, ok := asBytes(nil, sv); ok {
320 *d = b
321 return nil
322 }
323 case *RawBytes:
324 sv = reflect.ValueOf(src)
325 if b, ok := asBytes([]byte(*d)[:0], sv); ok {
326 *d = RawBytes(b)
327 return nil
328 }
329 case *bool:
330 bv, err := driver.Bool.ConvertValue(src)
331 if err == nil {
332 *d = bv.(bool)
333 }
334 return err
335 case *interface{}:
336 *d = src
337 return nil
338 }
339
340 if scanner, ok := dest.(Scanner); ok {
341 return scanner.Scan(src)
342 }
343
344 dpv := reflect.ValueOf(dest)
345 if dpv.Kind() != reflect.Ptr {
346 return errors.New("destination not a pointer")
347 }
348 if dpv.IsNil() {
349 return errNilPtr
350 }
351
352 if !sv.IsValid() {
353 sv = reflect.ValueOf(src)
354 }
355
356 dv := reflect.Indirect(dpv)
357 if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
358 switch b := src.(type) {
359 case []byte:
360 dv.Set(reflect.ValueOf(cloneBytes(b)))
361 default:
362 dv.Set(sv)
363 }
364 return nil
365 }
366
367 if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
368 dv.Set(sv.Convert(dv.Type()))
369 return nil
370 }
371
372
373
374
375
376
377 switch dv.Kind() {
378 case reflect.Ptr:
379 if src == nil {
380 dv.Set(reflect.Zero(dv.Type()))
381 return nil
382 } else {
383 dv.Set(reflect.New(dv.Type().Elem()))
384 return convertAssign(dv.Interface(), src)
385 }
386 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
387 s := asString(src)
388 i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
389 if err != nil {
390 err = strconvErr(err)
391 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
392 }
393 dv.SetInt(i64)
394 return nil
395 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
396 s := asString(src)
397 u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
398 if err != nil {
399 err = strconvErr(err)
400 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
401 }
402 dv.SetUint(u64)
403 return nil
404 case reflect.Float32, reflect.Float64:
405 s := asString(src)
406 f64, err := strconv.ParseFloat(s, dv.Type().Bits())
407 if err != nil {
408 err = strconvErr(err)
409 return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
410 }
411 dv.SetFloat(f64)
412 return nil
413 case reflect.String:
414 switch v := src.(type) {
415 case string:
416 dv.SetString(v)
417 return nil
418 case []byte:
419 dv.SetString(string(v))
420 return nil
421 }
422 }
423
424 return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
425 }
426
// strconvErr unwraps a *strconv.NumError to its underlying cause (for example
// strconv.ErrSyntax or strconv.ErrRange); any other error, including nil, is
// returned unchanged.
func strconvErr(err error) error {
	numErr, isNumErr := err.(*strconv.NumError)
	if !isNumErr {
		return err
	}
	return numErr.Err
}
433
// cloneBytes returns a copy of b that does not share b's backing array.
// A nil slice stays nil and an empty non-nil slice yields an empty non-nil
// copy, so the nil/empty distinction is preserved.
func cloneBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	c := make([]byte, len(b))
	copy(c, b)
	return c
}
443
// asString renders src as text. Strings and byte slices are returned
// verbatim; bool, integer, unsigned and float kinds are formatted with
// strconv (floats in the shortest 'g' form at the value's own precision);
// anything else falls back to fmt's %v verb.
func asString(src interface{}) string {
	if s, isString := src.(string); isString {
		return s
	}
	if b, isBytes := src.([]byte); isBytes {
		return string(b)
	}

	v := reflect.ValueOf(src)
	switch v.Kind() {
	case reflect.Bool:
		return strconv.FormatBool(v.Bool())
	case reflect.Float32, reflect.Float64:
		return strconv.FormatFloat(v.Float(), 'g', -1, v.Type().Bits())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.FormatUint(v.Uint(), 10)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.FormatInt(v.Int(), 10)
	default:
		return fmt.Sprintf("%v", src)
	}
}
466
// asBytes appends the text form of rv to buf and reports ok=true for string,
// bool, integer, unsigned and float kinds (floats use the shortest 'g' form at
// the value's own precision). For any other kind, including an invalid Value,
// it returns (nil, false) and leaves buf untouched.
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
	switch rv.Kind() {
	case reflect.String:
		return append(buf, rv.String()...), true
	case reflect.Bool:
		return strconv.AppendBool(buf, rv.Bool()), true
	case reflect.Float32, reflect.Float64:
		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, rv.Type().Bits()), true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return strconv.AppendUint(buf, rv.Uint(), 10), true
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return strconv.AppendInt(buf, rv.Int(), 10), true
	}
	return nil, false
}
485
// valuerReflectType is the reflect.Type of the driver.Valuer interface itself.
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

// callValuerValue returns vr.Value(), except for one case: vr is a nil pointer
// *T where T itself implements driver.Valuer (i.e. Value has a value
// receiver). Calling Value through that nil pointer would panic while
// dereferencing it, so (nil, nil) is returned instead. A nil pointer whose
// Value method has a pointer receiver is still called normally, letting that
// method decide how to handle the nil receiver.
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
	rv := reflect.ValueOf(vr)
	isNilPtr := rv.Kind() == reflect.Ptr && rv.IsNil()
	if isNilPtr && rv.Type().Elem().Implements(valuerReflectType) {
		return nil, nil
	}
	return vr.Value()
}
507
View as plain text