1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "crypto/cipher"
12 "crypto/subtle"
13 "crypto/x509"
14 "errors"
15 "fmt"
16 "io"
17 "net"
18 "sync"
19 "sync/atomic"
20 "time"
21 )
22
23
24
// Conn is a TLS connection layered over an underlying net.Conn. It provides
// the net.Conn methods (Read, Write, Close, LocalAddr, RemoteAddr and the
// deadline setters) on top of the TLS record layer implemented by the
// in/out halfConns.
type Conn struct {
	// conn is the underlying transport; the address and deadline methods
	// delegate straight to it.
	conn net.Conn
	// isClient selects clientHandshake (true) or serverHandshake (false)
	// in Handshake.
	isClient bool

	// handshakeMutex serializes handshakes and guards the handshake
	// results below; ConnectionState, OCSPResponse and VerifyHostname
	// take it before reading them.
	handshakeMutex sync.Mutex
	// handshakeCond is non-nil while a handshake is running. Other
	// Handshake callers Wait on it (it uses handshakeMutex as its lock)
	// and are woken by Broadcast when the handshake finishes.
	handshakeCond *sync.Cond
	// handshakeErr is the result of the most recent handshake attempt.
	handshakeErr error
	// vers is the protocol version in use. Once haveVers is true,
	// readRecord rejects records carrying any other version.
	vers     uint16
	haveVers bool
	config   *Config

	// handshakeComplete is true after a successful handshake. readRecord
	// uses it to decide which record types may be requested, and
	// handleRenegotiation clears it before re-handshaking.
	handshakeComplete bool

	// handshakes counts successful handshakes (initial plus any
	// renegotiations); RenegotiateOnceAsClient consults it.
	handshakes int
	// The following are reported through ConnectionState once the
	// handshake is complete.
	didResume        bool
	cipherSuite      uint16
	ocspResponse     []byte
	scts             [][]byte
	peerCertificates []*x509.Certificate

	// verifiedChains must be non-empty for VerifyHostname to check the
	// first peer certificate.
	verifiedChains [][]*x509.Certificate

	// serverName is reported as ConnectionState.ServerName regardless of
	// handshake state.
	serverName string

	// secureRenegotiation is not read in this part of the file.
	// NOTE(review): presumably records whether the peer supports secure
	// renegotiation (RFC 5746); confirm in the handshake code.
	secureRenegotiation bool

	// clientFinishedIsFirst selects which of clientFinished/serverFinished
	// ConnectionState reports as TLSUnique.
	clientFinishedIsFirst bool

	// closeNotifyErr is the result of the single close_notify attempt made
	// by closeNotify.
	closeNotifyErr error

	// closeNotifySent records that close_notify was attempted; closeNotify
	// then returns the cached closeNotifyErr and Write returns errShutdown.
	closeNotifySent bool

	// clientFinished and serverFinished back ConnectionState.TLSUnique.
	// NOTE(review): presumably the verify_data of each side's Finished
	// message; they are filled in by the handshake code outside this view.
	clientFinished [12]byte
	serverFinished [12]byte

	// Reported as NegotiatedProtocol and !NegotiatedProtocolIsMutual.
	clientProtocol         string
	clientProtocolFallback bool

	// in and out are the read and write halves of the record layer.
	in, out halfConn
	// rawInput holds bytes read from conn that have not yet been consumed
	// as a record by readRecord.
	rawInput *block
	// input holds a decrypted application-data record not yet fully
	// returned by Read.
	input *block
	// hand accumulates handshake bytes until readHandshake consumes a
	// complete message.
	hand bytes.Buffer
	// While buffering is true, write appends to sendBuf instead of writing
	// to conn; flush sends sendBuf and clears buffering.
	buffering bool
	sendBuf   []byte

	// bytesSent (bytes written to conn by write/flush) and packetsSent
	// (records sized so far) drive dynamic record sizing in
	// maxPayloadSizeForWrite.
	bytesSent   int64
	packetsSent int64

	// warnCount counts consecutive warning alerts. readRecord resets it on
	// any non-empty non-alert record and fails once it exceeds
	// maxWarnAlertCount.
	warnCount int

	// activeCall interlocks Write and Close without a mutex: bit 0 is set
	// by Close, and each in-flight Write adds 2 for its duration. Accessed
	// only through sync/atomic.
	activeCall int32

	// tmp is scratch space; sendAlertLocked builds the two alert bytes here.
	tmp [16]byte
}
108
109
110
111
112
113
114 func (c *Conn) LocalAddr() net.Addr {
115 return c.conn.LocalAddr()
116 }
117
118
119 func (c *Conn) RemoteAddr() net.Addr {
120 return c.conn.RemoteAddr()
121 }
122
123
124
125
126 func (c *Conn) SetDeadline(t time.Time) error {
127 return c.conn.SetDeadline(t)
128 }
129
130
131
132 func (c *Conn) SetReadDeadline(t time.Time) error {
133 return c.conn.SetReadDeadline(t)
134 }
135
136
137
138
139 func (c *Conn) SetWriteDeadline(t time.Time) error {
140 return c.conn.SetWriteDeadline(t)
141 }
142
143
144
// halfConn is one direction of a Conn's record layer (c.in for reading,
// c.out for writing). It holds the active and pending cipher state, the
// record sequence number and a free list of record buffers. The embedded
// Mutex is taken by Read/Handshake for c.in and by writeRecord, sendAlert,
// Write and closeNotify for c.out.
type halfConn struct {
	sync.Mutex

	err     error       // sticky error, recorded via setErrorLocked
	version uint16      // protocol version, set by prepareCipherSpec
	cipher  interface{} // active cipher: nil, cipher.Stream, aead or cbcMode (see decrypt/encrypt)
	mac     macFunction // active MAC, or nil
	seq     [8]byte     // 64-bit big-endian record sequence number (see incSeq)
	bfree   *block      // free list of reusable blocks (newBlock/freeBlock)
	// additionalData is scratch space for AEAD additional data:
	// seq (8 bytes) || record type+version (3) || plaintext length (2).
	additionalData [13]byte

	nextCipher interface{} // pending cipher, installed by changeCipherSpec
	nextMac    macFunction // pending MAC, installed by changeCipherSpec

	// inDigestBuf and outDigestBuf are passed to mac.MAC and then replaced
	// by its result. NOTE(review): presumably this lets the MAC reuse its
	// output buffer across records; confirm against macFunction.
	inDigestBuf, outDigestBuf []byte
}
162
163 func (hc *halfConn) setErrorLocked(err error) error {
164 hc.err = err
165 return err
166 }
167
168
169
170 func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
171 hc.version = version
172 hc.nextCipher = cipher
173 hc.nextMac = mac
174 }
175
176
177
178 func (hc *halfConn) changeCipherSpec() error {
179 if hc.nextCipher == nil {
180 return alertInternalError
181 }
182 hc.cipher = hc.nextCipher
183 hc.mac = hc.nextMac
184 hc.nextCipher = nil
185 hc.nextMac = nil
186 for i := range hc.seq {
187 hc.seq[i] = 0
188 }
189 return nil
190 }
191
192
193 func (hc *halfConn) incSeq() {
194 for i := 7; i >= 0; i-- {
195 hc.seq[i]++
196 if hc.seq[i] != 0 {
197 return
198 }
199 }
200
201
202
203
204 panic("TLS: sequence number wraparound")
205 }
206
207
208
209
// extractPadding inspects the TLS CBC padding at the end of payload (the
// last byte is the padding length, preceded by that many bytes equal to it).
// It returns toRemove, the number of trailing bytes to strip (padding plus
// the length byte), and good, which is 255 if the padding is well formed
// and 0 otherwise.
//
// The checks are written without data-dependent branches or memory
// accesses on the padding contents, so that timing does not reveal whether
// (or where) the padding was wrong. Only len(payload), which is public,
// influences control flow.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// If paddingLen <= len(payload)-1 the padding fits: t is small, so ^t
	// has its low-32-bit sign bit set and the shift yields 0xff. Otherwise
	// t wrapped around and the result is 0.
	good = byte(int32(^t) >> 31)

	// At most 255 padding bytes plus the length byte can be involved.
	toCheck := 256
	// The payload length is public, so branching on it is fine.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen (byte is part of the padding),
		// 0 otherwise.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// For padding bytes, clear any bit of good where b differs from
		// paddingLen; non-padding bytes contribute nothing.
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all eight bits of good into the top bit, then smear the top bit
	// across the byte so good ends up exactly 0xff or 0x00.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	toRemove = int(paddingLen) + 1
	return
}
245
246
247
248
// extractPaddingSSL30 returns how many trailing bytes of an SSL 3.0 CBC
// payload are padding (the final length byte plus that many bytes), with
// good = 255 when that fits inside payload. SSL 3.0 does not define the
// padding contents, so only the length is checked. An empty payload or an
// overlong padding length yields (0, 0).
func extractPaddingSSL30(payload []byte) (toRemove int, good byte) {
	size := len(payload)
	if size == 0 {
		return 0, 0
	}
	toRemove = 1 + int(payload[size-1])
	if toRemove <= size {
		return toRemove, 255
	}
	return 0, 0
}
261
// roundUp rounds a up to the next multiple of b (a itself if it is already
// a multiple). It is intended for a >= 0 and b > 0.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
265
266
// cbcMode is a CBC block mode whose IV can be replaced between records.
// decrypt and encrypt call SetIV with the explicit per-record IV used from
// TLS 1.1 onwards.
type cbcMode interface {
	cipher.BlockMode
	SetIV(iv []byte)
}
271
272
273
274
// decrypt decrypts and authenticates the record in b in place using the
// half's active cipher and MAC.
//
// On success it returns ok = true and prefixLen, the offset within b.data
// at which the plaintext starts (record header plus any explicit IV/nonce).
// b is resized so that b.data[prefixLen:] is exactly the plaintext, and the
// sequence number is incremented. On failure it returns ok = false and the
// alert to send (always alertBadRecordMAC here). The sequence number is left
// unchanged in that case.
func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) {
	// Everything after the 5-byte record header.
	payload := b.data[recordHeaderLen:]

	macSize := 0
	if hc.mac != nil {
		macSize = hc.mac.Size()
	}

	// Defaults for ciphers without padding: nothing to strip, padding "good".
	paddingGood := byte(255)
	paddingLen := 0
	explicitIVLen := 0

	// Decrypt.
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			explicitIVLen = c.explicitNonceLen()
			if len(payload) < explicitIVLen {
				return false, 0, alertBadRecordMAC
			}
			nonce := payload[:explicitIVLen]
			payload = payload[explicitIVLen:]

			// No explicit nonce on the wire: the sequence number is used.
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}

			// Additional data = seq || type+version || plaintext length.
			// NOTE(review): if the ciphertext is shorter than Overhead(), n
			// is negative and the length bytes are garbage; Open is then
			// expected to fail, which is reported as a bad MAC.
			copy(hc.additionalData[:], hc.seq[:])
			copy(hc.additionalData[8:], b.data[:3])
			n := len(payload) - c.Overhead()
			hc.additionalData[11] = byte(n >> 8)
			hc.additionalData[12] = byte(n)
			var err error
			payload, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:])
			if err != nil {
				return false, 0, alertBadRecordMAC
			}
			b.resize(recordHeaderLen + explicitIVLen + len(payload))
		case cbcMode:
			blockSize := c.BlockSize()
			// TLS 1.1+ records start with an explicit IV block.
			if hc.version >= VersionTLS11 {
				explicitIVLen = blockSize
			}

			// Reject lengths that cannot be a whole number of blocks holding
			// IV + MAC + at least the padding-length byte. These depend only
			// on the public record length.
			if len(payload)%blockSize != 0 || len(payload) < roundUp(explicitIVLen+macSize+1, blockSize) {
				return false, 0, alertBadRecordMAC
			}

			if explicitIVLen > 0 {
				c.SetIV(payload[:explicitIVLen])
				payload = payload[explicitIVLen:]
			}
			c.CryptBlocks(payload, payload)
			// Padding validity is recorded but not acted on yet; it is
			// checked together with the MAC below so that a padding failure
			// is indistinguishable from a MAC failure.
			if hc.version == VersionSSL30 {
				paddingLen, paddingGood = extractPaddingSSL30(payload)
			} else {
				paddingLen, paddingGood = extractPadding(payload)
			}
		default:
			panic("unknown cipher type")
		}
	}

	// Check and strip the MAC.
	if hc.mac != nil {
		if len(payload) < macSize {
			return false, 0, alertBadRecordMAC
		}

		// n is the plaintext length. If bad padding made it negative, clamp
		// it to 0 without branching (uint32(n)>>31 is 1 exactly when n < 0).
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers the header with the plaintext length, so rewrite
		// the length field before computing it.
		b.data[3] = byte(n >> 8)
		b.data[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		// The bytes after the MAC (the padding) are passed as the final
		// argument. NOTE(review): presumably so the MAC computation's cost
		// does not depend on the padding length; confirm against macFunction.
		localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// One combined check: bad MAC and bad padding yield the same alert.
		if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
			return false, 0, alertBadRecordMAC
		}
		hc.inDigestBuf = localMAC

		b.resize(recordHeaderLen + explicitIVLen + n)
	}
	hc.incSeq()

	return true, recordHeaderLen + explicitIVLen, 0
}
373
374
375
376
377
378
// padToBlockSize splits payload into the whole blocks it already contains
// (prefix, which aliases payload) and a freshly allocated finalBlock. That
// block holds the leftover bytes followed by TLS-style padding: every
// padding byte, including the last, equals the padding length minus one.
// A payload that is already block-aligned gets a full block of padding.
func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
	tail := len(payload) % blockSize
	cut := len(payload) - tail
	padByte := byte(blockSize - tail - 1)

	prefix = payload[:cut]
	finalBlock = make([]byte, blockSize)
	filled := copy(finalBlock, payload[cut:])
	for i := filled; i < blockSize; i++ {
		finalBlock[i] = padByte
	}
	return prefix, finalBlock
}
390
391
// encrypt protects the record in b in place with the half's active MAC and
// cipher. On entry b.data holds the record header, explicitIVLen bytes of
// explicit IV/nonce (already filled in by the caller) and the plaintext.
// On return the header's length field describes the protected payload and
// the sequence number has been incremented. It always reports (true, 0).
func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) {
	// MAC the header and plaintext (excluding the explicit IV) and append it.
	if hc.mac != nil {
		mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil)

		n := len(b.data)
		b.resize(n + len(mac))
		copy(b.data[n:], mac)
		hc.outDigestBuf = mac
	}

	payload := b.data[recordHeaderLen:]

	// Encrypt.
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			payloadLen := len(b.data) - recordHeaderLen - explicitIVLen
			// Reserve room for the AEAD overhead (e.g. the tag).
			b.resize(len(b.data) + c.Overhead())
			nonce := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
			// No explicit nonce on the wire: the sequence number is used.
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload := b.data[recordHeaderLen+explicitIVLen:]
			payload = payload[:payloadLen]

			// Additional data = seq || type+version || plaintext length.
			copy(hc.additionalData[:], hc.seq[:])
			copy(hc.additionalData[8:], b.data[:3])
			hc.additionalData[11] = byte(payloadLen >> 8)
			hc.additionalData[12] = byte(payloadLen)

			// Seal in place; the output fills the space reserved above.
			c.Seal(payload[:0], nonce, payload, hc.additionalData[:])
		case cbcMode:
			blockSize := c.BlockSize()
			if explicitIVLen > 0 {
				c.SetIV(payload[:explicitIVLen])
				payload = payload[explicitIVLen:]
			}
			prefix, finalBlock := padToBlockSize(payload, blockSize)
			// If resize has to reallocate, prefix still refers to the old
			// backing array. CryptBlocks then reads from there and writes
			// into the new one, which yields the same result as the in-place
			// case.
			b.resize(recordHeaderLen + explicitIVLen + len(prefix) + len(finalBlock))
			c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen:], prefix)
			c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen+len(prefix):], finalBlock)
		default:
			panic("unknown cipher type")
		}
	}

	// Update the header length to include MAC, tag and any padding.
	n := len(b.data) - recordHeaderLen
	b.data[3] = byte(n >> 8)
	b.data[4] = byte(n)
	hc.incSeq()

	return true, 0
}
449
450
// block is a growable byte buffer holding raw or decrypted records. Bytes
// not yet consumed by Read are data[off:]; link chains blocks on a
// halfConn's free list.
type block struct {
	data []byte
	off  int    // read position for Read
	link *block // next block on the free list
}

// resize sets len(b.data) to n, growing the backing array if needed.
func (b *block) resize(n int) {
	b.reserve(n)
	b.data = b.data[:n]
}

// reserve makes sure cap(b.data) >= n while keeping the current contents.
// Capacity starts at 1024 bytes and doubles until it is large enough.
func (b *block) reserve(n int) {
	if n <= cap(b.data) {
		return
	}
	size := cap(b.data)
	if size == 0 {
		size = 1024
	}
	for size < n {
		size <<= 1
	}
	grown := make([]byte, len(b.data), size)
	copy(grown, b.data)
	b.data = grown
}

// readFromUntil keeps reading from r into b until b holds at least n bytes.
// A read error is returned only if it arrives before n bytes are available.
// Any bytes read alongside the error are still kept.
func (b *block) readFromUntil(r io.Reader, n int) error {
	if len(b.data) >= n {
		return nil
	}
	b.reserve(n)
	for len(b.data) < n {
		got, err := r.Read(b.data[len(b.data):cap(b.data)])
		b.data = b.data[:len(b.data)+got]
		if len(b.data) < n && err != nil {
			return err
		}
	}
	return nil
}

// Read copies unread bytes (data[off:]) into p and advances off. It never
// reports an error; once the data is exhausted it returns 0, nil.
func (b *block) Read(p []byte) (n int, err error) {
	n = copy(p, b.data[b.off:])
	b.off += n
	return n, nil
}
512
513
514 func (hc *halfConn) newBlock() *block {
515 b := hc.bfree
516 if b == nil {
517 return new(block)
518 }
519 hc.bfree = b.link
520 b.link = nil
521 b.resize(0)
522 return b
523 }
524
525
526
527
528
529 func (hc *halfConn) freeBlock(b *block) {
530 b.link = hc.bfree
531 hc.bfree = b
532 }
533
534
535
536
537 func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
538 if len(b.data) <= n {
539 return b, nil
540 }
541 bb := hc.newBlock()
542 bb.resize(len(b.data) - n)
543 copy(bb.data, b.data[n:])
544 b.data = b.data[0:n]
545 return b, bb
546 }
547
548
// RecordHeaderError is returned by readRecord when the header of an
// incoming record is unacceptable. Examples are an SSLv2-style hello, a
// version mismatch, an oversized length, or a first record that does not
// look like TLS.
type RecordHeaderError struct {
	// Msg describes what was wrong with the header.
	Msg string
	// RecordHeader holds the first five bytes read for the offending record.
	RecordHeader [5]byte
}

// Error implements the error interface, prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
558
559 func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
560 err.Msg = msg
561 copy(err.RecordHeader[:], c.rawInput.data)
562 return err
563 }
564
565
566
567
// readRecord reads the next TLS record from c.conn, decrypts and
// authenticates it with c.in, and dispatches it by record type:
//
//   - alert: close_notify makes c.in's error io.EOF; warning alerts are
//     dropped and the next record is read (failing after more than
//     maxWarnAlertCount in a row); error alerts become a "remote error"
//     *net.OpError.
//   - change_cipher_spec: activates the pending read cipher.
//   - application data: stored in c.input for Read.
//   - handshake: appended to c.hand. While application data is wanted this
//     is accepted only on clients whose config allows renegotiation.
//
// want must match the handshake state: handshake/CCS only before the
// handshake completes, application data only after.
//
// It returns c.in's sticky error. Network errors reporting Temporary() are
// returned without being made sticky.
// NOTE(review): c.in's state is mutated throughout, so the caller presumably
// holds c.in's lock (Read and Handshake do); confirm for other callers.
func (c *Conn) readRecord(want recordType) error {
	// The requested type must agree with the handshake state.
	switch want {
	default:
		c.sendAlert(alertInternalError)
		return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
	case recordTypeHandshake, recordTypeChangeCipherSpec:
		if c.handshakeComplete {
			c.sendAlert(alertInternalError)
			return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
		}
	case recordTypeApplicationData:
		if !c.handshakeComplete {
			c.sendAlert(alertInternalError)
			return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
		}
	}

Again:
	if c.rawInput == nil {
		c.rawInput = c.in.newBlock()
	}
	b := c.rawInput

	// Read the record header.
	if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// A plain io.EOF here (no bytes of a new record) is returned as-is,
		// unlike the body read below, which maps io.EOF to
		// io.ErrUnexpectedEOF. Temporary network errors are not made sticky.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	typ := recordType(b.data[0])

	// 0x80 is not a TLS record type. Per the error below, a record
	// starting with it while a handshake is expected is treated as an
	// SSLv2-format hello.
	if want == recordTypeHandshake && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received"))
	}

	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
	n := int(b.data[3])<<8 | int(b.data[4])
	if c.haveVers && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(msg))
	}
	if n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(msg))
	}
	if !c.haveVers {
		// First record on the connection: be suspicious before reading a
		// potentially large body. It must be an alert or the wanted type,
		// and a version >= 0x1000 is treated as implausible.
		if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 {
			c.sendAlert(alertUnexpectedMessage)
			return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake"))
		}
	}
	// Read the record body; EOF mid-record is unexpected.
	if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Detach this record; any bytes beyond it stay in c.rawInput.
	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
	ok, off, alertValue := c.in.decrypt(b)
	if !ok {
		c.in.freeBlock(b)
		return c.in.setErrorLocked(c.sendAlert(alertValue))
	}
	b.off = off
	data := b.data[b.off:]
	if len(data) > maxPlaintext {
		err := c.sendAlert(alertRecordOverflow)
		c.in.freeBlock(b)
		return c.in.setErrorLocked(err)
	}

	if typ != recordTypeAlert && len(data) > 0 {
		// A non-empty, non-alert record ends any run of warning alerts.
		c.warnCount = 0
	}

	switch typ {
	default:
		c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		if alert(data[1]) == alertCloseNotify {
			c.in.setErrorLocked(io.EOF)
			break
		}
		switch data[0] {
		case alertLevelWarning:
			// Ignore the warning and read the next record.
			c.in.freeBlock(b)

			c.warnCount++
			if c.warnCount > maxWarnAlertCount {
				c.sendAlert(alertUnexpectedMessage)
				return c.in.setErrorLocked(errors.New("tls: too many warn alerts"))
			}

			goto Again
		case alertLevelError:
			c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if typ != want || len(data) != 1 || data[0] != 1 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		// A handshake message may not straddle the ChangeCipherSpec.
		if c.hand.Len() > 0 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		err := c.in.changeCipherSpec()
		if err != nil {
			c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if typ != want {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		// Ownership of b passes to c.input; don't free it below.
		c.input = b
		b = nil

	case recordTypeHandshake:
		// Unexpected handshake data is only tolerated on clients that may
		// renegotiate; Read then hands it to handleRenegotiation.
		if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) {
			return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
		}
		c.hand.Write(data)
	}

	if b != nil {
		c.in.freeBlock(b)
	}
	return c.in.err
}
737
738
739
740 func (c *Conn) sendAlertLocked(err alert) error {
741 switch err {
742 case alertNoRenegotiation, alertCloseNotify:
743 c.tmp[0] = alertLevelWarning
744 default:
745 c.tmp[0] = alertLevelError
746 }
747 c.tmp[1] = byte(err)
748
749 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
750 if err == alertCloseNotify {
751
752 return writeErr
753 }
754
755 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
756 }
757
758
759
760 func (c *Conn) sendAlert(err alert) error {
761 c.out.Lock()
762 defer c.out.Unlock()
763 return c.sendAlertLocked(err)
764 }
765
const (
	// tcpMSSEstimate is a fixed estimate of the TCP maximum segment size.
	// maxPayloadSizeForWrite starts from it when sizing early
	// application-data records, before subtracting the record header,
	// explicit IV and cipher overhead.
	// NOTE(review): 1208 equals 1280 (IPv6 minimum MTU) - 40 (IPv6 header)
	// - 32 (TCP header with timestamps); presumably that is its derivation.
	// Confirm.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes sent on the
	// connection (c.bytesSent) after which maxPayloadSizeForWrite stops
	// limiting application-data records and always allows maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798 func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
799 if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
800 return maxPlaintext
801 }
802
803 if c.bytesSent >= recordSizeBoostThreshold {
804 return maxPlaintext
805 }
806
807
808 macSize := 0
809 if c.out.mac != nil {
810 macSize = c.out.mac.Size()
811 }
812
813 payloadBytes := tcpMSSEstimate - recordHeaderLen - explicitIVLen
814 if c.out.cipher != nil {
815 switch ciph := c.out.cipher.(type) {
816 case cipher.Stream:
817 payloadBytes -= macSize
818 case cipher.AEAD:
819 payloadBytes -= ciph.Overhead()
820 case cbcMode:
821 blockSize := ciph.BlockSize()
822
823
824 payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
825
826
827 payloadBytes -= macSize
828 default:
829 panic("unknown cipher type")
830 }
831 }
832
833
834 pkt := c.packetsSent
835 c.packetsSent++
836 if pkt > 1000 {
837 return maxPlaintext
838 }
839
840 n := payloadBytes * int(pkt+1)
841 if n > maxPlaintext {
842 n = maxPlaintext
843 }
844 return n
845 }
846
847
848 func (c *Conn) write(data []byte) (int, error) {
849 if c.buffering {
850 c.sendBuf = append(c.sendBuf, data...)
851 return len(data), nil
852 }
853
854 n, err := c.conn.Write(data)
855 c.bytesSent += int64(n)
856 return n, err
857 }
858
859 func (c *Conn) flush() (int, error) {
860 if len(c.sendBuf) == 0 {
861 return 0, nil
862 }
863
864 n, err := c.conn.Write(c.sendBuf)
865 c.bytesSent += int64(n)
866 c.sendBuf = nil
867 c.buffering = false
868 return n, err
869 }
870
871
872
873
874 func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
875 b := c.out.newBlock()
876 defer c.out.freeBlock(b)
877
878 var n int
879 for len(data) > 0 {
880 explicitIVLen := 0
881 explicitIVIsSeq := false
882
883 var cbc cbcMode
884 if c.out.version >= VersionTLS11 {
885 var ok bool
886 if cbc, ok = c.out.cipher.(cbcMode); ok {
887 explicitIVLen = cbc.BlockSize()
888 }
889 }
890 if explicitIVLen == 0 {
891 if c, ok := c.out.cipher.(aead); ok {
892 explicitIVLen = c.explicitNonceLen()
893
894
895
896
897
898
899
900 explicitIVIsSeq = explicitIVLen > 0
901 }
902 }
903 m := len(data)
904 if maxPayload := c.maxPayloadSizeForWrite(typ, explicitIVLen); m > maxPayload {
905 m = maxPayload
906 }
907 b.resize(recordHeaderLen + explicitIVLen + m)
908 b.data[0] = byte(typ)
909 vers := c.vers
910 if vers == 0 {
911
912
913 vers = VersionTLS10
914 }
915 b.data[1] = byte(vers >> 8)
916 b.data[2] = byte(vers)
917 b.data[3] = byte(m >> 8)
918 b.data[4] = byte(m)
919 if explicitIVLen > 0 {
920 explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
921 if explicitIVIsSeq {
922 copy(explicitIV, c.out.seq[:])
923 } else {
924 if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
925 return n, err
926 }
927 }
928 }
929 copy(b.data[recordHeaderLen+explicitIVLen:], data)
930 c.out.encrypt(b, explicitIVLen)
931 if _, err := c.write(b.data); err != nil {
932 return n, err
933 }
934 n += m
935 data = data[m:]
936 }
937
938 if typ == recordTypeChangeCipherSpec {
939 if err := c.out.changeCipherSpec(); err != nil {
940 return n, c.sendAlertLocked(err.(alert))
941 }
942 }
943
944 return n, nil
945 }
946
947
948
949
950 func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
951 c.out.Lock()
952 defer c.out.Unlock()
953
954 return c.writeRecordLocked(typ, data)
955 }
956
957
958
959
960 func (c *Conn) readHandshake() (interface{}, error) {
961 for c.hand.Len() < 4 {
962 if err := c.in.err; err != nil {
963 return nil, err
964 }
965 if err := c.readRecord(recordTypeHandshake); err != nil {
966 return nil, err
967 }
968 }
969
970 data := c.hand.Bytes()
971 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
972 if n > maxHandshake {
973 c.sendAlertLocked(alertInternalError)
974 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
975 }
976 for c.hand.Len() < 4+n {
977 if err := c.in.err; err != nil {
978 return nil, err
979 }
980 if err := c.readRecord(recordTypeHandshake); err != nil {
981 return nil, err
982 }
983 }
984 data = c.hand.Next(4 + n)
985 var m handshakeMessage
986 switch data[0] {
987 case typeHelloRequest:
988 m = new(helloRequestMsg)
989 case typeClientHello:
990 m = new(clientHelloMsg)
991 case typeServerHello:
992 m = new(serverHelloMsg)
993 case typeNewSessionTicket:
994 m = new(newSessionTicketMsg)
995 case typeCertificate:
996 m = new(certificateMsg)
997 case typeCertificateRequest:
998 m = &certificateRequestMsg{
999 hasSignatureAndHash: c.vers >= VersionTLS12,
1000 }
1001 case typeCertificateStatus:
1002 m = new(certificateStatusMsg)
1003 case typeServerKeyExchange:
1004 m = new(serverKeyExchangeMsg)
1005 case typeServerHelloDone:
1006 m = new(serverHelloDoneMsg)
1007 case typeClientKeyExchange:
1008 m = new(clientKeyExchangeMsg)
1009 case typeCertificateVerify:
1010 m = &certificateVerifyMsg{
1011 hasSignatureAndHash: c.vers >= VersionTLS12,
1012 }
1013 case typeNextProtocol:
1014 m = new(nextProtoMsg)
1015 case typeFinished:
1016 m = new(finishedMsg)
1017 default:
1018 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1019 }
1020
1021
1022
1023
1024 data = append([]byte(nil), data...)
1025
1026 if !m.unmarshal(data) {
1027 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1028 }
1029 return m, nil
1030 }
1031
// errClosed is returned by Write and Close once Close has been called.
var errClosed = errors.New("tls: use of closed connection")

// errShutdown is returned by Write after close_notify has been sent.
var errShutdown = errors.New("tls: protocol is shutdown")
1036
1037
// Write sends b as application data, running the handshake first if it has
// not completed. Any error from writing records becomes c.out's sticky
// error, so after a failed Write every later Write fails the same way. It
// returns errClosed after Close and errShutdown after close_notify was sent.
func (c *Conn) Write(b []byte) (int, error) {
	// Interlock with Close: fail if bit 0 (closed) is set, otherwise add 2
	// to register this in-flight Write. The defer inside the loop is
	// intentional; it runs when Write returns and unregisters the call.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			defer atomic.AddInt32(&c.activeCall, -2)
			break
		}
	}

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// On TLS 1.0 and earlier with a block cipher, send the first byte in
	// a record of its own and the rest in a second record.
	// NOTE(review): presumably the 1/n-1 record-splitting countermeasure
	// against predictable CBC IVs (the BEAST attack) in SSL 3.0/TLS 1.0.
	// Confirm.
	var m int
	if len(b) > 1 && c.vers <= VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1093
1094
1095
1096 func (c *Conn) handleRenegotiation() error {
1097 msg, err := c.readHandshake()
1098 if err != nil {
1099 return err
1100 }
1101
1102 _, ok := msg.(*helloRequestMsg)
1103 if !ok {
1104 c.sendAlert(alertUnexpectedMessage)
1105 return alertUnexpectedMessage
1106 }
1107
1108 if !c.isClient {
1109 return c.sendAlert(alertNoRenegotiation)
1110 }
1111
1112 switch c.config.Renegotiation {
1113 case RenegotiateNever:
1114 return c.sendAlert(alertNoRenegotiation)
1115 case RenegotiateOnceAsClient:
1116 if c.handshakes > 1 {
1117 return c.sendAlert(alertNoRenegotiation)
1118 }
1119 case RenegotiateFreelyAsClient:
1120
1121 default:
1122 c.sendAlert(alertInternalError)
1123 return errors.New("tls: unknown Renegotiation value")
1124 }
1125
1126 c.handshakeMutex.Lock()
1127 defer c.handshakeMutex.Unlock()
1128
1129 c.handshakeComplete = false
1130 if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
1131 c.handshakes++
1132 }
1133 return c.handshakeErr
1134 }
1135
1136
1137
// Read reads application data into b, running the handshake first if it
// has not completed. A close_notify from the peer surfaces as io.EOF. If
// more than maxConsecutiveEmptyRecords empty records arrive in a row, Read
// gives up with io.ErrNoProgress.
func (c *Conn) Read(b []byte) (n int, err error) {
	if err = c.Handshake(); err != nil {
		return
	}
	if len(b) == 0 {
		// Checked after Handshake so that Read(nil) still drives the
		// handshake.
		return
	}

	c.in.Lock()
	defer c.in.Unlock()

	// Bound the number of consecutive empty records accepted, so a peer
	// cannot keep Read looping without ever returning.
	const maxConsecutiveEmptyRecords = 100
	for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
		for c.input == nil && c.in.err == nil {
			if err := c.readRecord(recordTypeApplicationData); err != nil {
				// Either a sticky error or a temporary network error
				// (which readRecord does not make sticky).
				return 0, err
			}
			if c.hand.Len() > 0 {
				// Handshake bytes after the handshake: the peer wants to
				// renegotiate.
				if err := c.handleRenegotiation(); err != nil {
					return 0, err
				}
			}
		}
		if err := c.in.err; err != nil {
			return 0, err
		}

		n, err = c.input.Read(b)
		if c.input.off >= len(c.input.data) {
			c.in.freeBlock(c.input)
			c.input = nil
		}

		// If this record is now fully consumed and the next buffered raw
		// record is an alert, process it now. Its error (io.EOF for
		// close_notify) is then returned together with this data, rather
		// than on the next call.
		if ri := c.rawInput; ri != nil &&
			n != 0 && err == nil &&
			c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert {
			if recErr := c.readRecord(recordTypeApplicationData); recErr != nil {
				err = recErr
			}
		}

		// n == 0 with no error means an empty record; try the next one.
		if n != 0 || err != nil {
			return n, err
		}
	}

	return 0, io.ErrNoProgress
}
1204
1205
// Close closes the connection. If the handshake has completed and no Write
// is in progress, it first sends close_notify. The underlying connection's
// Close error takes precedence over the alert error. A second Close returns
// errClosed.
func (c *Conn) Close() error {
	// Interlock with Write: atomically set bit 0 (closed). x keeps the
	// previous value; a nonzero x means Writes were in flight.
	var x int32
	for {
		x = atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			break
		}
	}
	if x != 0 {
		// A Write is in progress: skip close_notify and just close the
		// transport. NOTE(review): presumably so Close does not block on
		// c.out or handshakeMutex held by that Write, and can instead
		// unblock it.
		return c.conn.Close()
	}

	var alertErr error

	c.handshakeMutex.Lock()
	if c.handshakeComplete {
		alertErr = c.closeNotify()
	}
	c.handshakeMutex.Unlock()

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1241
1242 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1243
1244
1245
1246
1247 func (c *Conn) CloseWrite() error {
1248 c.handshakeMutex.Lock()
1249 defer c.handshakeMutex.Unlock()
1250 if !c.handshakeComplete {
1251 return errEarlyCloseWrite
1252 }
1253
1254 return c.closeNotify()
1255 }
1256
1257 func (c *Conn) closeNotify() error {
1258 c.out.Lock()
1259 defer c.out.Unlock()
1260
1261 if !c.closeNotifySent {
1262 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1263 c.closeNotifySent = true
1264 }
1265 return c.closeNotifyErr
1266 }
1267
1268
1269
1270
1271
// Handshake runs the client or server handshake if it has not run yet and
// returns its result. Later calls return the recorded result. Read and Write
// call it automatically, so most callers need not call it directly.
// Concurrent callers wait for a single in-progress handshake instead of
// starting their own.
func (c *Conn) Handshake() error {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	for {
		if err := c.handshakeErr; err != nil {
			return err
		}
		if c.handshakeComplete {
			return nil
		}
		if c.handshakeCond == nil {
			break
		}

		// Another goroutine is handshaking. Wait releases handshakeMutex
		// until that goroutine Broadcasts, then we re-check the result.
		c.handshakeCond.Wait()
	}

	// Mark the handshake as in progress. Then drop handshakeMutex so that
	// c.in can be taken first: Read holds c.in when handleRenegotiation
	// takes handshakeMutex, so the lock order is c.in -> handshakeMutex.
	c.handshakeCond = sync.NewCond(&c.handshakeMutex)
	c.handshakeMutex.Unlock()

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeMutex.Lock()

	// With handshakeCond set, no other goroutine may have completed a
	// handshake in the unlocked window above.
	if c.handshakeErr != nil || c.handshakeComplete {
		panic("handshake should not have been able to complete after handshakeCond was set")
	}

	if c.isClient {
		c.handshakeErr = c.clientHandshake()
	} else {
		c.handshakeErr = c.serverHandshake()
	}
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// Push out anything still sitting in c.sendBuf.
		// NOTE(review): presumably buffered handshake records or an alert
		// describing the failure; confirm in the handshake code.
		c.flush()
	}

	if c.handshakeErr == nil && !c.handshakeComplete {
		panic("handshake should have had a result.")
	}

	// Wake waiters; they see the result once they reacquire handshakeMutex.
	c.handshakeCond.Broadcast()
	c.handshakeCond = nil

	return c.handshakeErr
}
1352
1353
1354 func (c *Conn) ConnectionState() ConnectionState {
1355 c.handshakeMutex.Lock()
1356 defer c.handshakeMutex.Unlock()
1357
1358 var state ConnectionState
1359 state.HandshakeComplete = c.handshakeComplete
1360 state.ServerName = c.serverName
1361
1362 if c.handshakeComplete {
1363 state.Version = c.vers
1364 state.NegotiatedProtocol = c.clientProtocol
1365 state.DidResume = c.didResume
1366 state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
1367 state.CipherSuite = c.cipherSuite
1368 state.PeerCertificates = c.peerCertificates
1369 state.VerifiedChains = c.verifiedChains
1370 state.SignedCertificateTimestamps = c.scts
1371 state.OCSPResponse = c.ocspResponse
1372 if !c.didResume {
1373 if c.clientFinishedIsFirst {
1374 state.TLSUnique = c.clientFinished[:]
1375 } else {
1376 state.TLSUnique = c.serverFinished[:]
1377 }
1378 }
1379 }
1380
1381 return state
1382 }
1383
1384
1385
1386 func (c *Conn) OCSPResponse() []byte {
1387 c.handshakeMutex.Lock()
1388 defer c.handshakeMutex.Unlock()
1389
1390 return c.ocspResponse
1391 }
1392
1393
1394
1395
1396 func (c *Conn) VerifyHostname(host string) error {
1397 c.handshakeMutex.Lock()
1398 defer c.handshakeMutex.Unlock()
1399 if !c.isClient {
1400 return errors.New("tls: VerifyHostname called on TLS server connection")
1401 }
1402 if !c.handshakeComplete {
1403 return errors.New("tls: handshake has not yet been performed")
1404 }
1405 if len(c.verifiedChains) == 0 {
1406 return errors.New("tls: handshake did not verify certificate chain")
1407 }
1408 return c.peerCertificates[0].VerifyHostname(host)
1409 }
1410
View as plain text