1
2
3
4
5
6
7 package httptest
8
9 import (
10 "bytes"
11 "crypto/tls"
12 "crypto/x509"
13 "flag"
14 "fmt"
15 "log"
16 "net"
17 "net/http"
18 "net/http/internal"
19 "os"
20 "sync"
21 "time"
22 )
23
24
25
// A Server is an HTTP server for use in end-to-end HTTP tests. By default it
// listens on a system-chosen port on the loopback interface (see
// newLocalListener).
type Server struct {
	// URL is the base URL of the form "http://host:port" (or "https://..."
	// after StartTLS), built from Listener.Addr(); it has no trailing slash.
	// It is empty until Start or StartTLS is called.
	URL      string
	Listener net.Listener

	// TLS is the optional TLS configuration. StartTLS replaces it with a
	// clone of the caller's config (or a fresh config when nil) and fills in
	// NextProtos and Certificates if they are unset.
	TLS *tls.Config

	// Config is the http.Server used to serve requests. Start/StartTLS
	// install a ConnState hook on it and call its Serve method, so changes
	// should be made before either is called.
	Config *http.Server

	// certificate is the parsed leaf of TLS.Certificates[0]; it is set only
	// by StartTLS and stays nil for plain-HTTP servers.
	certificate *x509.Certificate

	// wg counts the Serve goroutine started by goServe plus every connection
	// tracked in conns (added on StateNew, released on StateHijacked or
	// StateClosed). Close blocks on it.
	wg sync.WaitGroup

	mu     sync.Mutex // guards closed and conns
	closed bool
	conns  map[net.Conn]http.ConnState // tracked connections and their last non-terminal state

	// client is created by Start/StartTLS and returned by Client. Close
	// closes the idle connections held by its transport.
	client *http.Client
}
54
55 func newLocalListener() net.Listener {
56 if *serve != "" {
57 l, err := net.Listen("tcp", *serve)
58 if err != nil {
59 panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
60 }
61 return l
62 }
63 l, err := net.Listen("tcp", "127.0.0.1:0")
64 if err != nil {
65 if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
66 panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
67 }
68 }
69 return l
70 }
71
72
73
74
75
// serve holds the value of the -httptest.serve flag. When it is non-empty,
// newLocalListener listens on that address, and Start prints the server URL
// to stderr and then blocks forever instead of returning.
var (
	serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
)
77
78
79
80 func NewServer(handler http.Handler) *Server {
81 ts := NewUnstartedServer(handler)
82 ts.Start()
83 return ts
84 }
85
86
87
88
89
90
91
92 func NewUnstartedServer(handler http.Handler) *Server {
93 return &Server{
94 Listener: newLocalListener(),
95 Config: &http.Server{Handler: handler},
96 }
97 }
98
99
100 func (s *Server) Start() {
101 if s.URL != "" {
102 panic("Server already started")
103 }
104 if s.client == nil {
105 s.client = &http.Client{Transport: &http.Transport{}}
106 }
107 s.URL = "http://" + s.Listener.Addr().String()
108 s.wrap()
109 s.goServe()
110 if *serve != "" {
111 fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
112 select {}
113 }
114 }
115
116
117 func (s *Server) StartTLS() {
118 if s.URL != "" {
119 panic("Server already started")
120 }
121 if s.client == nil {
122 s.client = &http.Client{Transport: &http.Transport{}}
123 }
124 cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
125 if err != nil {
126 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
127 }
128
129 existingConfig := s.TLS
130 if existingConfig != nil {
131 s.TLS = existingConfig.Clone()
132 } else {
133 s.TLS = new(tls.Config)
134 }
135 if s.TLS.NextProtos == nil {
136 s.TLS.NextProtos = []string{"http/1.1"}
137 }
138 if len(s.TLS.Certificates) == 0 {
139 s.TLS.Certificates = []tls.Certificate{cert}
140 }
141 s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
142 if err != nil {
143 panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
144 }
145 certpool := x509.NewCertPool()
146 certpool.AddCert(s.certificate)
147 s.client.Transport = &http.Transport{
148 TLSClientConfig: &tls.Config{
149 RootCAs: certpool,
150 },
151 }
152 s.Listener = tls.NewListener(s.Listener, s.TLS)
153 s.URL = "https://" + s.Listener.Addr().String()
154 s.wrap()
155 s.goServe()
156 }
157
158
159
160 func NewTLSServer(handler http.Handler) *Server {
161 ts := NewUnstartedServer(handler)
162 ts.StartTLS()
163 return ts
164 }
165
// closeIdleTransport is implemented by HTTP transports that can drop their
// idle connections (for example *http.Transport). Close type-asserts
// http.DefaultTransport and the server's client transport against it.
type closeIdleTransport interface {
	CloseIdleConnections()
}
169
170
171
// Close shuts down the server and blocks until the Serve goroutine has
// returned and every connection tracked in s.conns has been closed or
// hijacked.
//
// Calling Close more than once is safe: only the first call closes the
// listener and idle connections, but every call waits on s.wg.
func (s *Server) Close() {
	s.mu.Lock()
	if !s.closed {
		s.closed = true
		// Stop accepting new connections and disable HTTP keep-alives.
		s.Listener.Close()
		s.Config.SetKeepAlivesEnabled(false)
		for c, st := range s.conns {
			// Force-close connections that are not doing anything: idle
			// ones (between requests) and new ones (connected, no request
			// yet). StateActive connections are left alone so in-flight
			// requests can finish; when they later go idle, the ConnState
			// hook installed by wrap closes them because s.closed is set.
			if st == http.StateIdle || st == http.StateNew {
				s.closeConn(c)
			}
		}

		// If shutdown has not finished within 5 seconds, log which
		// connections are still outstanding. The deferred Stop runs when
		// Close returns (after s.wg.Wait), not at the end of this block.
		t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
		defer t.Stop()
	}
	// Release the lock before waiting: the ConnState hook needs s.mu to
	// forget connections and release them from s.wg.
	s.mu.Unlock()

	// Not needed for the server's own correctness: close all idle
	// connections held by http.DefaultTransport (not only those to this
	// server), on the assumption that many callers use it.
	if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
		t.CloseIdleConnections()
	}

	// Also close idle connections held by the client returned from Client.
	if s.client != nil {
		if t, ok := s.client.Transport.(closeIdleTransport); ok {
			t.CloseIdleConnections()
		}
	}

	s.wg.Wait()
}
223
224 func (s *Server) logCloseHangDebugInfo() {
225 s.mu.Lock()
226 defer s.mu.Unlock()
227 var buf bytes.Buffer
228 buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
229 for c, st := range s.conns {
230 fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
231 }
232 log.Print(buf.String())
233 }
234
235
236 func (s *Server) CloseClientConnections() {
237 s.mu.Lock()
238 nconn := len(s.conns)
239 ch := make(chan struct{}, nconn)
240 for c := range s.conns {
241 go s.closeConnChan(c, ch)
242 }
243 s.mu.Unlock()
244
245
246
247
248
249
250
251 timer := time.NewTimer(5 * time.Second)
252 defer timer.Stop()
253 for i := 0; i < nconn; i++ {
254 select {
255 case <-ch:
256 case <-timer.C:
257
258 return
259 }
260 }
261 }
262
263
264
265 func (s *Server) Certificate() *x509.Certificate {
266 return s.certificate
267 }
268
269
270
271
272 func (s *Server) Client() *http.Client {
273 return s.client
274 }
275
276 func (s *Server) goServe() {
277 s.wg.Add(1)
278 go func() {
279 defer s.wg.Done()
280 s.Config.Serve(s.Listener)
281 }()
282 }
283
284
285
286 func (s *Server) wrap() {
287 oldHook := s.Config.ConnState
288 s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
289 s.mu.Lock()
290 defer s.mu.Unlock()
291 switch cs {
292 case http.StateNew:
293 s.wg.Add(1)
294 if _, exists := s.conns[c]; exists {
295 panic("invalid state transition")
296 }
297 if s.conns == nil {
298 s.conns = make(map[net.Conn]http.ConnState)
299 }
300 s.conns[c] = cs
301 if s.closed {
302
303
304
305
306 s.closeConn(c)
307 }
308 case http.StateActive:
309 if oldState, ok := s.conns[c]; ok {
310 if oldState != http.StateNew && oldState != http.StateIdle {
311 panic("invalid state transition")
312 }
313 s.conns[c] = cs
314 }
315 case http.StateIdle:
316 if oldState, ok := s.conns[c]; ok {
317 if oldState != http.StateActive {
318 panic("invalid state transition")
319 }
320 s.conns[c] = cs
321 }
322 if s.closed {
323 s.closeConn(c)
324 }
325 case http.StateHijacked, http.StateClosed:
326 s.forgetConn(c)
327 }
328 if oldHook != nil {
329 oldHook(c, cs)
330 }
331 }
332 }
333
334
335
336 func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) }
337
338
339
340 func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
341 c.Close()
342 if done != nil {
343 done <- struct{}{}
344 }
345 }
346
347
348
349
350 func (s *Server) forgetConn(c net.Conn) {
351 if _, ok := s.conns[c]; ok {
352 delete(s.conns, c)
353 s.wg.Done()
354 }
355 }
356
View as plain text