1
2
3
4
5 package template
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "net/url"
13 "reflect"
14 "strings"
15 "unicode"
16 "unicode/utf8"
17 )
18
19
20
21
22
23
24
25
26
27
28
29
30 type FuncMap map[string]interface{}
31
32 var builtins = FuncMap{
33 "and": and,
34 "call": call,
35 "html": HTMLEscaper,
36 "index": index,
37 "js": JSEscaper,
38 "len": length,
39 "not": not,
40 "or": or,
41 "print": fmt.Sprint,
42 "printf": fmt.Sprintf,
43 "println": fmt.Sprintln,
44 "urlquery": URLQueryEscaper,
45
46
47 "eq": eq,
48 "ge": ge,
49 "gt": gt,
50 "le": le,
51 "lt": lt,
52 "ne": ne,
53 }
54
55 var builtinFuncs = createValueFuncs(builtins)
56
57
58 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
59 m := make(map[string]reflect.Value)
60 addValueFuncs(m, funcMap)
61 return m
62 }
63
64
65 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
66 for name, fn := range in {
67 if !goodName(name) {
68 panic(fmt.Errorf("function name %s is not a valid identifier", name))
69 }
70 v := reflect.ValueOf(fn)
71 if v.Kind() != reflect.Func {
72 panic("value for " + name + " not a function")
73 }
74 if !goodFunc(v.Type()) {
75 panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
76 }
77 out[name] = v
78 }
79 }
80
81
82
83 func addFuncs(out, in FuncMap) {
84 for name, fn := range in {
85 out[name] = fn
86 }
87 }
88
89
90 func goodFunc(typ reflect.Type) bool {
91
92 switch {
93 case typ.NumOut() == 1:
94 return true
95 case typ.NumOut() == 2 && typ.Out(1) == errorType:
96 return true
97 }
98 return false
99 }
100
101
// goodName reports whether name is usable as a function name: it must be
// non-empty, consist only of letters, digits and underscores, and must not
// start with a digit.
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' || unicode.IsLetter(ch) {
			// Underscores and letters are allowed in every position.
			continue
		}
		// Anything else must be a digit, and never the first rune.
		if pos == 0 || !unicode.IsDigit(ch) {
			return false
		}
	}
	return true
}
117
118
119 func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
120 if tmpl != nil && tmpl.common != nil {
121 tmpl.muFuncs.RLock()
122 defer tmpl.muFuncs.RUnlock()
123 if fn := tmpl.execFuncs[name]; fn.IsValid() {
124 return fn, true
125 }
126 }
127 if fn := builtinFuncs[name]; fn.IsValid() {
128 return fn, true
129 }
130 return reflect.Value{}, false
131 }
132
133
134
135 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
136 if !value.IsValid() {
137 if !canBeNil(argType) {
138 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
139 }
140 value = reflect.Zero(argType)
141 }
142 if value.Type().AssignableTo(argType) {
143 return value, nil
144 }
145 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
146 value = value.Convert(argType)
147 return value, nil
148 }
149 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
150 }
151
// intLike reports whether typ is one of the signed or unsigned integer
// kinds, including reflect.Uintptr.
func intLike(typ reflect.Kind) bool {
	switch typ {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
161
162
163
164
165
166
167 func index(item reflect.Value, indices ...reflect.Value) (reflect.Value, error) {
168 v := indirectInterface(item)
169 if !v.IsValid() {
170 return reflect.Value{}, fmt.Errorf("index of untyped nil")
171 }
172 for _, i := range indices {
173 index := indirectInterface(i)
174 var isNil bool
175 if v, isNil = indirect(v); isNil {
176 return reflect.Value{}, fmt.Errorf("index of nil pointer")
177 }
178 switch v.Kind() {
179 case reflect.Array, reflect.Slice, reflect.String:
180 var x int64
181 switch index.Kind() {
182 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
183 x = index.Int()
184 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
185 x = int64(index.Uint())
186 case reflect.Invalid:
187 return reflect.Value{}, fmt.Errorf("cannot index slice/array with nil")
188 default:
189 return reflect.Value{}, fmt.Errorf("cannot index slice/array with type %s", index.Type())
190 }
191 if x < 0 || x >= int64(v.Len()) {
192 return reflect.Value{}, fmt.Errorf("index out of range: %d", x)
193 }
194 v = v.Index(int(x))
195 case reflect.Map:
196 index, err := prepareArg(index, v.Type().Key())
197 if err != nil {
198 return reflect.Value{}, err
199 }
200 if x := v.MapIndex(index); x.IsValid() {
201 v = x
202 } else {
203 v = reflect.Zero(v.Type().Elem())
204 }
205 case reflect.Invalid:
206
207 panic("unreachable")
208 default:
209 return reflect.Value{}, fmt.Errorf("can't index item of type %s", v.Type())
210 }
211 }
212 return v, nil
213 }
214
215
216
217
218 func length(item interface{}) (int, error) {
219 v := reflect.ValueOf(item)
220 if !v.IsValid() {
221 return 0, fmt.Errorf("len of untyped nil")
222 }
223 v, isNil := indirect(v)
224 if isNil {
225 return 0, fmt.Errorf("len of nil pointer")
226 }
227 switch v.Kind() {
228 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
229 return v.Len(), nil
230 }
231 return 0, fmt.Errorf("len of type %s", v.Type())
232 }
233
234
235
236
237
238 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
239 v := indirectInterface(fn)
240 if !v.IsValid() {
241 return reflect.Value{}, fmt.Errorf("call of nil")
242 }
243 typ := v.Type()
244 if typ.Kind() != reflect.Func {
245 return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
246 }
247 if !goodFunc(typ) {
248 return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
249 }
250 numIn := typ.NumIn()
251 var dddType reflect.Type
252 if typ.IsVariadic() {
253 if len(args) < numIn-1 {
254 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
255 }
256 dddType = typ.In(numIn - 1).Elem()
257 } else {
258 if len(args) != numIn {
259 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
260 }
261 }
262 argv := make([]reflect.Value, len(args))
263 for i, arg := range args {
264 value := indirectInterface(arg)
265
266 var argType reflect.Type
267 if !typ.IsVariadic() || i < numIn-1 {
268 argType = typ.In(i)
269 } else {
270 argType = dddType
271 }
272
273 var err error
274 if argv[i], err = prepareArg(value, argType); err != nil {
275 return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
276 }
277 }
278 result := v.Call(argv)
279 if len(result) == 2 && !result[1].IsNil() {
280 return result[0], result[1].Interface().(error)
281 }
282 return result[0], nil
283 }
284
285
286
287 func truth(arg reflect.Value) bool {
288 t, _ := isTrue(indirectInterface(arg))
289 return t
290 }
291
292
293
294 func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
295 if !truth(arg0) {
296 return arg0
297 }
298 for i := range args {
299 arg0 = args[i]
300 if !truth(arg0) {
301 break
302 }
303 }
304 return arg0
305 }
306
307
308
309 func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
310 if truth(arg0) {
311 return arg0
312 }
313 for i := range args {
314 arg0 = args[i]
315 if truth(arg0) {
316 break
317 }
318 }
319 return arg0
320 }
321
322
323 func not(arg reflect.Value) bool {
324 return !truth(arg)
325 }
326
327
328
329
330
// Errors reported by the comparison functions.
var (
	// errBadComparisonType is returned by basicKind for a value whose kind
	// cannot be compared, and by lt for bool and complex operands.
	errBadComparisonType = errors.New("invalid type for comparison")

	// errBadComparison is returned by eq and lt when the two operands have
	// different basic kinds other than a signed/unsigned integer mix.
	errBadComparison = errors.New("incompatible types for comparison")

	// errNoComparison is returned by eq when it receives only one argument.
	errNoComparison = errors.New("missing argument for comparison")
)
336
// kind is the coarse classification of a value that the comparison
// functions work with; basicKind maps a reflect.Kind onto it.
type kind int

// The basic kinds. invalidKind is the zero value and is what basicKind
// returns alongside an error.
const (
	invalidKind = kind(iota)
	boolKind
	complexKind
	intKind
	floatKind
	stringKind
	uintKind
)
348
349 func basicKind(v reflect.Value) (kind, error) {
350 switch v.Kind() {
351 case reflect.Bool:
352 return boolKind, nil
353 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
354 return intKind, nil
355 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
356 return uintKind, nil
357 case reflect.Float32, reflect.Float64:
358 return floatKind, nil
359 case reflect.Complex64, reflect.Complex128:
360 return complexKind, nil
361 case reflect.String:
362 return stringKind, nil
363 }
364 return invalidKind, errBadComparisonType
365 }
366
367
368 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
369 v1 := indirectInterface(arg1)
370 k1, err := basicKind(v1)
371 if err != nil {
372 return false, err
373 }
374 if len(arg2) == 0 {
375 return false, errNoComparison
376 }
377 for _, arg := range arg2 {
378 v2 := indirectInterface(arg)
379 k2, err := basicKind(v2)
380 if err != nil {
381 return false, err
382 }
383 truth := false
384 if k1 != k2 {
385
386 switch {
387 case k1 == intKind && k2 == uintKind:
388 truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
389 case k1 == uintKind && k2 == intKind:
390 truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
391 default:
392 return false, errBadComparison
393 }
394 } else {
395 switch k1 {
396 case boolKind:
397 truth = v1.Bool() == v2.Bool()
398 case complexKind:
399 truth = v1.Complex() == v2.Complex()
400 case floatKind:
401 truth = v1.Float() == v2.Float()
402 case intKind:
403 truth = v1.Int() == v2.Int()
404 case stringKind:
405 truth = v1.String() == v2.String()
406 case uintKind:
407 truth = v1.Uint() == v2.Uint()
408 default:
409 panic("invalid kind")
410 }
411 }
412 if truth {
413 return true, nil
414 }
415 }
416 return false, nil
417 }
418
419
420 func ne(arg1, arg2 reflect.Value) (bool, error) {
421
422 equal, err := eq(arg1, arg2)
423 return !equal, err
424 }
425
426
427 func lt(arg1, arg2 reflect.Value) (bool, error) {
428 v1 := indirectInterface(arg1)
429 k1, err := basicKind(v1)
430 if err != nil {
431 return false, err
432 }
433 v2 := indirectInterface(arg2)
434 k2, err := basicKind(v2)
435 if err != nil {
436 return false, err
437 }
438 truth := false
439 if k1 != k2 {
440
441 switch {
442 case k1 == intKind && k2 == uintKind:
443 truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
444 case k1 == uintKind && k2 == intKind:
445 truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
446 default:
447 return false, errBadComparison
448 }
449 } else {
450 switch k1 {
451 case boolKind, complexKind:
452 return false, errBadComparisonType
453 case floatKind:
454 truth = v1.Float() < v2.Float()
455 case intKind:
456 truth = v1.Int() < v2.Int()
457 case stringKind:
458 truth = v1.String() < v2.String()
459 case uintKind:
460 truth = v1.Uint() < v2.Uint()
461 default:
462 panic("invalid kind")
463 }
464 }
465 return truth, nil
466 }
467
468
469 func le(arg1, arg2 reflect.Value) (bool, error) {
470
471 lessThan, err := lt(arg1, arg2)
472 if lessThan || err != nil {
473 return lessThan, err
474 }
475 return eq(arg1, arg2)
476 }
477
478
479 func gt(arg1, arg2 reflect.Value) (bool, error) {
480
481 lessOrEqual, err := le(arg1, arg2)
482 if err != nil {
483 return false, err
484 }
485 return !lessOrEqual, nil
486 }
487
488
489 func ge(arg1, arg2 reflect.Value) (bool, error) {
490
491 lessThan, err := lt(arg1, arg2)
492 if err != nil {
493 return false, err
494 }
495 return !lessThan, nil
496 }
497
498
499
// Replacement text written by HTMLEscape for each byte it escapes.
//
// The values must be the escaped entity forms. Storing the literal
// characters here would turn HTMLEscape into a no-op, and an unescaped
// double quote is not even a valid Go string literal.
var (
	htmlQuot = []byte("&#34;") // numeric form, shorter than "&quot;"
	htmlApos = []byte("&#39;") // numeric form, shorter than "&apos;"
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
	htmlNull = []byte("\uFFFD") // a NUL byte is replaced by U+FFFD
)
508
509
510 func HTMLEscape(w io.Writer, b []byte) {
511 last := 0
512 for i, c := range b {
513 var html []byte
514 switch c {
515 case '\000':
516 html = htmlNull
517 case '"':
518 html = htmlQuot
519 case '\'':
520 html = htmlApos
521 case '&':
522 html = htmlAmp
523 case '<':
524 html = htmlLt
525 case '>':
526 html = htmlGt
527 default:
528 continue
529 }
530 w.Write(b[last:i])
531 w.Write(html)
532 last = i + 1
533 }
534 w.Write(b[last:])
535 }
536
537
538 func HTMLEscapeString(s string) string {
539
540 if !strings.ContainsAny(s, "'\"&<>\000") {
541 return s
542 }
543 var b bytes.Buffer
544 HTMLEscape(&b, []byte(s))
545 return b.String()
546 }
547
548
549
550 func HTMLEscaper(args ...interface{}) string {
551 return HTMLEscapeString(evalArgs(args))
552 }
553
554
555
// Byte sequences written by JSEscape.
var (
	// Fixed replacements for individual ASCII characters.
	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)
	jsLt        = []byte(`\x3C`)
	jsGt        = []byte(`\x3E`)

	// jsLowUni is the prefix of the escape for any other special ASCII
	// byte; JSEscape follows it with two digits taken from hex.
	jsLowUni = []byte(`\u00`)
	hex      = []byte("0123456789ABCDEF")
)
566
567
568 func JSEscape(w io.Writer, b []byte) {
569 last := 0
570 for i := 0; i < len(b); i++ {
571 c := b[i]
572
573 if !jsIsSpecial(rune(c)) {
574
575 continue
576 }
577 w.Write(b[last:i])
578
579 if c < utf8.RuneSelf {
580
581
582 switch c {
583 case '\\':
584 w.Write(jsBackslash)
585 case '\'':
586 w.Write(jsApos)
587 case '"':
588 w.Write(jsQuot)
589 case '<':
590 w.Write(jsLt)
591 case '>':
592 w.Write(jsGt)
593 default:
594 w.Write(jsLowUni)
595 t, b := c>>4, c&0x0f
596 w.Write(hex[t : t+1])
597 w.Write(hex[b : b+1])
598 }
599 } else {
600
601 r, size := utf8.DecodeRune(b[i:])
602 if unicode.IsPrint(r) {
603 w.Write(b[i : i+size])
604 } else {
605 fmt.Fprintf(w, "\\u%04X", r)
606 }
607 i += size - 1
608 }
609 last = i + 1
610 }
611 w.Write(b[last:])
612 }
613
614
615 func JSEscapeString(s string) string {
616
617 if strings.IndexFunc(s, jsIsSpecial) < 0 {
618 return s
619 }
620 var b bytes.Buffer
621 JSEscape(&b, []byte(s))
622 return b.String()
623 }
624
// jsIsSpecial reports whether r needs attention from JSEscape: any rune
// below the space character, any rune at or above utf8.RuneSelf, and the
// characters backslash, single quote, double quote, '<' and '>'.
func jsIsSpecial(r rune) bool {
	if r < ' ' || r >= utf8.RuneSelf {
		return true
	}
	return r == '\\' || r == '\'' || r == '"' || r == '<' || r == '>'
}
632
633
634
635 func JSEscaper(args ...interface{}) string {
636 return JSEscapeString(evalArgs(args))
637 }
638
639
640
641 func URLQueryEscaper(args ...interface{}) string {
642 return url.QueryEscape(evalArgs(args))
643 }
644
645
646
647
648
649
650 func evalArgs(args []interface{}) string {
651 ok := false
652 var s string
653
654 if len(args) == 1 {
655 s, ok = args[0].(string)
656 }
657 if !ok {
658 for i, arg := range args {
659 a, ok := printableValue(reflect.ValueOf(arg))
660 if ok {
661 args[i] = a
662 }
663 }
664 s = fmt.Sprint(args...)
665 }
666 return s
667 }
668
View as plain text