1
2
3
4
5
6 package base64
7
8 import (
9 "encoding/binary"
10 "io"
11 "strconv"
12 )
13
14
17
18
19
20
21
22
// An Encoding is a radix-64 encoding/decoding scheme defined by a
// 64-character alphabet, an optional padding character and a strictness
// flag.
type Encoding struct {
	// encode maps a 6-bit value to the alphabet byte that represents it.
	encode [64]byte
	// decodeMap maps an input byte back to its 6-bit value; entries for
	// bytes outside the alphabet hold 0xFF.
	decodeMap [256]byte
	// padChar is the padding character, or NoPadding when output is unpadded.
	padChar rune
	// strict makes decoding reject input whose unused trailing bits are
	// not zero.
	strict bool
}
29
// Padding settings understood by Encoding.WithPadding.
const (
	// StdPadding is the standard '=' padding character.
	StdPadding = rune('=')
	// NoPadding selects an encoding that emits and expects no padding.
	NoPadding = rune(-1)
)
34
// The two built-in 64-symbol alphabets. They differ only in the symbols
// used for the values 62 and 63.
const (
	// encodeStd ends in "+/".
	encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
	// encodeURL ends in "-_", which are safe in URLs and file names.
	encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
)
37
38
39
40
41
42
43 func NewEncoding(encoder string) *Encoding {
44 if len(encoder) != 64 {
45 panic("encoding alphabet is not 64-bytes long")
46 }
47 for i := 0; i < len(encoder); i++ {
48 if encoder[i] == '\n' || encoder[i] == '\r' {
49 panic("encoding alphabet contains newline character")
50 }
51 }
52
53 e := new(Encoding)
54 e.padChar = StdPadding
55 copy(e.encode[:], encoder)
56
57 for i := 0; i < len(e.decodeMap); i++ {
58 e.decodeMap[i] = 0xFF
59 }
60 for i := 0; i < len(encoder); i++ {
61 e.decodeMap[encoder[i]] = byte(i)
62 }
63 return e
64 }
65
66
67
68
69
70
71 func (enc Encoding) WithPadding(padding rune) *Encoding {
72 if padding == '\r' || padding == '\n' || padding > 0xff {
73 panic("invalid padding")
74 }
75
76 for i := 0; i < len(enc.encode); i++ {
77 if rune(enc.encode[i]) == padding {
78 panic("padding contained in alphabet")
79 }
80 }
81
82 enc.padChar = padding
83 return &enc
84 }
85
86
87
88
89 func (enc Encoding) Strict() *Encoding {
90 enc.strict = true
91 return &enc
92 }
93
94
95
96 var StdEncoding = NewEncoding(encodeStd)
97
98
99
100 var URLEncoding = NewEncoding(encodeURL)
101
102
103
104
105 var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
106
107
108
109
110 var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
111
112
115
116
117
118
119
120
121
122 func (enc *Encoding) Encode(dst, src []byte) {
123 if len(src) == 0 {
124 return
125 }
126
127 di, si := 0, 0
128 n := (len(src) / 3) * 3
129 for si < n {
130
131 val := uint(src[si+0])<<16 | uint(src[si+1])<<8 | uint(src[si+2])
132
133 dst[di+0] = enc.encode[val>>18&0x3F]
134 dst[di+1] = enc.encode[val>>12&0x3F]
135 dst[di+2] = enc.encode[val>>6&0x3F]
136 dst[di+3] = enc.encode[val&0x3F]
137
138 si += 3
139 di += 4
140 }
141
142 remain := len(src) - si
143 if remain == 0 {
144 return
145 }
146
147 val := uint(src[si+0]) << 16
148 if remain == 2 {
149 val |= uint(src[si+1]) << 8
150 }
151
152 dst[di+0] = enc.encode[val>>18&0x3F]
153 dst[di+1] = enc.encode[val>>12&0x3F]
154
155 switch remain {
156 case 2:
157 dst[di+2] = enc.encode[val>>6&0x3F]
158 if enc.padChar != NoPadding {
159 dst[di+3] = byte(enc.padChar)
160 }
161 case 1:
162 if enc.padChar != NoPadding {
163 dst[di+2] = byte(enc.padChar)
164 dst[di+3] = byte(enc.padChar)
165 }
166 }
167 }
168
169
170 func (enc *Encoding) EncodeToString(src []byte) string {
171 buf := make([]byte, enc.EncodedLen(len(src)))
172 enc.Encode(buf, src)
173 return string(buf)
174 }
175
176 type encoder struct {
177 err error
178 enc *Encoding
179 w io.Writer
180 buf [3]byte
181 nbuf int
182 out [1024]byte
183 }
184
185 func (e *encoder) Write(p []byte) (n int, err error) {
186 if e.err != nil {
187 return 0, e.err
188 }
189
190
191 if e.nbuf > 0 {
192 var i int
193 for i = 0; i < len(p) && e.nbuf < 3; i++ {
194 e.buf[e.nbuf] = p[i]
195 e.nbuf++
196 }
197 n += i
198 p = p[i:]
199 if e.nbuf < 3 {
200 return
201 }
202 e.enc.Encode(e.out[:], e.buf[:])
203 if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
204 return n, e.err
205 }
206 e.nbuf = 0
207 }
208
209
210 for len(p) >= 3 {
211 nn := len(e.out) / 4 * 3
212 if nn > len(p) {
213 nn = len(p)
214 nn -= nn % 3
215 }
216 e.enc.Encode(e.out[:], p[:nn])
217 if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
218 return n, e.err
219 }
220 n += nn
221 p = p[nn:]
222 }
223
224
225 for i := 0; i < len(p); i++ {
226 e.buf[i] = p[i]
227 }
228 e.nbuf = len(p)
229 n += len(p)
230 return
231 }
232
233
234
235 func (e *encoder) Close() error {
236
237 if e.err == nil && e.nbuf > 0 {
238 e.enc.Encode(e.out[:], e.buf[:e.nbuf])
239 _, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
240 e.nbuf = 0
241 }
242 return e.err
243 }
244
245
246
247
248
249
250 func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
251 return &encoder{enc: enc, w: w}
252 }
253
254
255
256 func (enc *Encoding) EncodedLen(n int) int {
257 if enc.padChar == NoPadding {
258 return (n*8 + 5) / 6
259 }
260 return (n + 2) / 3 * 4
261 }
262
263
266
// CorruptInputError is the error reported for malformed base64 input. Its
// integer value is the input byte offset associated with the problem.
type CorruptInputError int64

// Error formats the offset as a decimal number after a fixed message prefix.
func (e CorruptInputError) Error() string {
	offset := strconv.FormatInt(int64(e), 10)
	return "illegal base64 data at input byte " + offset
}
272
273
274
275
276
277
278 func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err error) {
279
280 var dbuf [4]byte
281 dinc, dlen := 3, 4
282
283 for j := 0; j < len(dbuf); j++ {
284 if len(src) == si {
285 switch {
286 case j == 0:
287 return si, 0, nil
288 case j == 1, enc.padChar != NoPadding:
289 return si, 0, CorruptInputError(si - j)
290 }
291 dinc, dlen = j-1, j
292 break
293 }
294 in := src[si]
295 si++
296
297 out := enc.decodeMap[in]
298 if out != 0xff {
299 dbuf[j] = out
300 continue
301 }
302
303 if in == '\n' || in == '\r' {
304 j--
305 continue
306 }
307
308 if rune(in) != enc.padChar {
309 return si, 0, CorruptInputError(si - 1)
310 }
311
312
313 switch j {
314 case 0, 1:
315
316 return si, 0, CorruptInputError(si - 1)
317 case 2:
318
319
320 for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
321 si++
322 }
323 if si == len(src) {
324
325 return si, 0, CorruptInputError(len(src))
326 }
327 if rune(src[si]) != enc.padChar {
328
329 return si, 0, CorruptInputError(si - 1)
330 }
331
332 si++
333 }
334
335
336 for si < len(src) && (src[si] == '\n' || src[si] == '\r') {
337 si++
338 }
339 if si < len(src) {
340
341 err = CorruptInputError(si)
342 }
343 dinc, dlen = 3, j
344 break
345 }
346
347
348 val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
349 dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
350 switch dlen {
351 case 4:
352 dst[2] = dbuf[2]
353 dbuf[2] = 0
354 fallthrough
355 case 3:
356 dst[1] = dbuf[1]
357 if enc.strict && dbuf[2] != 0 {
358 return si, 0, CorruptInputError(si - 1)
359 }
360 dbuf[1] = 0
361 fallthrough
362 case 2:
363 dst[0] = dbuf[0]
364 if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
365 return si, 0, CorruptInputError(si - 2)
366 }
367 }
368 dst = dst[dinc:]
369
370 return si, dlen - 1, err
371 }
372
373
374 func (enc *Encoding) DecodeString(s string) ([]byte, error) {
375 dbuf := make([]byte, enc.DecodedLen(len(s)))
376 n, err := enc.Decode(dbuf, []byte(s))
377 return dbuf[:n], err
378 }
379
380 type decoder struct {
381 err error
382 readErr error
383 enc *Encoding
384 r io.Reader
385 buf [1024]byte
386 nbuf int
387 out []byte
388 outbuf [1024 / 4 * 3]byte
389 }
390
391 func (d *decoder) Read(p []byte) (n int, err error) {
392
393 if len(d.out) > 0 {
394 n = copy(p, d.out)
395 d.out = d.out[n:]
396 return n, nil
397 }
398
399 if d.err != nil {
400 return 0, d.err
401 }
402
403
404
405
406 for d.nbuf < 4 && d.readErr == nil {
407 nn := len(p) / 3 * 4
408 if nn < 4 {
409 nn = 4
410 }
411 if nn > len(d.buf) {
412 nn = len(d.buf)
413 }
414 nn, d.readErr = d.r.Read(d.buf[d.nbuf:nn])
415 d.nbuf += nn
416 }
417
418 if d.nbuf < 4 {
419 if d.enc.padChar == NoPadding && d.nbuf > 0 {
420
421 var nw int
422 nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:d.nbuf])
423 d.nbuf = 0
424 d.out = d.outbuf[:nw]
425 n = copy(p, d.out)
426 d.out = d.out[n:]
427 if n > 0 || len(p) == 0 && len(d.out) > 0 {
428 return n, nil
429 }
430 if d.err != nil {
431 return 0, d.err
432 }
433 }
434 d.err = d.readErr
435 if d.err == io.EOF && d.nbuf > 0 {
436 d.err = io.ErrUnexpectedEOF
437 }
438 return 0, d.err
439 }
440
441
442 nr := d.nbuf / 4 * 4
443 nw := d.nbuf / 4 * 3
444 if nw > len(p) {
445 nw, d.err = d.enc.Decode(d.outbuf[:], d.buf[:nr])
446 d.out = d.outbuf[:nw]
447 n = copy(p, d.out)
448 d.out = d.out[n:]
449 } else {
450 n, d.err = d.enc.Decode(p, d.buf[:nr])
451 }
452 d.nbuf -= nr
453 copy(d.buf[:d.nbuf], d.buf[nr:])
454 return n, d.err
455 }
456
457
458
459
460
461
462 func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
463 if len(src) == 0 {
464 return 0, nil
465 }
466
467 si := 0
468 ilen := len(src)
469 olen := len(dst)
470 for strconv.IntSize >= 64 && ilen-si >= 8 && olen-n >= 8 {
471 if ok := enc.decode64(dst[n:], src[si:]); ok {
472 n += 6
473 si += 8
474 } else {
475 var ninc int
476 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
477 n += ninc
478 if err != nil {
479 return n, err
480 }
481 }
482 }
483
484 for ilen-si >= 4 && olen-n >= 4 {
485 if ok := enc.decode32(dst[n:], src[si:]); ok {
486 n += 3
487 si += 4
488 } else {
489 var ninc int
490 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
491 n += ninc
492 if err != nil {
493 return n, err
494 }
495 }
496 }
497
498 for si < len(src) {
499 var ninc int
500 si, ninc, err = enc.decodeQuantum(dst[n:], src, si)
501 n += ninc
502 if err != nil {
503 return n, err
504 }
505 }
506 return n, err
507 }
508
509
510
511
512 func (enc *Encoding) decode32(dst, src []byte) bool {
513 var dn, n uint32
514 if n = uint32(enc.decodeMap[src[0]]); n == 0xff {
515 return false
516 }
517 dn |= n << 26
518 if n = uint32(enc.decodeMap[src[1]]); n == 0xff {
519 return false
520 }
521 dn |= n << 20
522 if n = uint32(enc.decodeMap[src[2]]); n == 0xff {
523 return false
524 }
525 dn |= n << 14
526 if n = uint32(enc.decodeMap[src[3]]); n == 0xff {
527 return false
528 }
529 dn |= n << 8
530
531 binary.BigEndian.PutUint32(dst, dn)
532 return true
533 }
534
535
536
537
538 func (enc *Encoding) decode64(dst, src []byte) bool {
539 var dn, n uint64
540 if n = uint64(enc.decodeMap[src[0]]); n == 0xff {
541 return false
542 }
543 dn |= n << 58
544 if n = uint64(enc.decodeMap[src[1]]); n == 0xff {
545 return false
546 }
547 dn |= n << 52
548 if n = uint64(enc.decodeMap[src[2]]); n == 0xff {
549 return false
550 }
551 dn |= n << 46
552 if n = uint64(enc.decodeMap[src[3]]); n == 0xff {
553 return false
554 }
555 dn |= n << 40
556 if n = uint64(enc.decodeMap[src[4]]); n == 0xff {
557 return false
558 }
559 dn |= n << 34
560 if n = uint64(enc.decodeMap[src[5]]); n == 0xff {
561 return false
562 }
563 dn |= n << 28
564 if n = uint64(enc.decodeMap[src[6]]); n == 0xff {
565 return false
566 }
567 dn |= n << 22
568 if n = uint64(enc.decodeMap[src[7]]); n == 0xff {
569 return false
570 }
571 dn |= n << 16
572
573 binary.BigEndian.PutUint64(dst, dn)
574 return true
575 }
576
// newlineFilteringReader wraps a reader and removes '\r' and '\n' bytes
// from the data read through it.
type newlineFilteringReader struct {
	wrapped io.Reader
}

// Read reads from the wrapped reader into p and compacts the result in
// place, dropping every '\r' and '\n'. If a read yields only newline bytes
// it reads again, so a positive count is returned whenever the source
// delivered any other data. When the source returns no bytes, its count and
// error are passed through unchanged.
func (r *newlineFilteringReader) Read(p []byte) (int, error) {
	for {
		n, err := r.wrapped.Read(p)
		if n <= 0 {
			return n, err
		}

		kept := 0
		for _, c := range p[:n] {
			if c == '\r' || c == '\n' {
				continue
			}
			p[kept] = c
			kept++
		}
		if kept > 0 {
			return kept, err
		}
		// The whole chunk was newlines; fetch more.
	}
}
601
602
603 func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
604 return &decoder{enc: enc, r: &newlineFilteringReader{r}}
605 }
606
607
608
609 func (enc *Encoding) DecodedLen(n int) int {
610 if enc.padChar == NoPadding {
611
612 return n * 6 / 8
613 }
614
615 return n / 4 * 3
616 }
617
View as plain text