...
1
2
3
4
5 package httptest
6
7 import (
8 "bytes"
9 "fmt"
10 "io/ioutil"
11 "net/http"
12 "strconv"
13 "strings"
14 )
15
16
17
// ResponseRecorder records what a handler writes to it: the status code,
// the headers and the body. It implements http.ResponseWriter (Header,
// Write, WriteHeader) and http.Flusher (Flush). Use Result to obtain the
// recorded response as an *http.Response.
type ResponseRecorder struct {
	// Code is the status recorded by the first WriteHeader call. Write,
	// WriteString and Flush record an implicit 200 if no status was
	// written yet. NewRecorder initializes it to 200. On a zero-value
	// recorder that is never written to, it stays 0, which Result
	// reports as 200.
	Code int

	// HeaderMap holds the response headers and is the map returned by
	// Header. Besides what the handler sets, the first Write/WriteString
	// may add a sniffed Content-Type. The map keeps reflecting later
	// changes, whereas Result reports the snapshot taken when the header
	// was committed.
	HeaderMap http.Header

	// Body receives the data passed to Write and WriteString.
	// If nil, written data is silently discarded.
	Body *bytes.Buffer

	// Flushed reports whether Flush was called.
	Flushed bool

	// result caches the value built by the first Result call.
	result *http.Response
	// snapHeader is a copy of HeaderMap taken by the first WriteHeader
	// call, or by Result if the header was never committed.
	snapHeader http.Header
	// wroteHeader is set by the first WriteHeader call; later calls
	// are ignored.
	wroteHeader bool
}
44
45
46 func NewRecorder() *ResponseRecorder {
47 return &ResponseRecorder{
48 HeaderMap: make(http.Header),
49 Body: new(bytes.Buffer),
50 Code: 200,
51 }
52 }
53
54
55
// DefaultRemoteAddr is a fixed placeholder address ("1.2.3.4", with no
// port). Nothing in this file reads it.
// NOTE(review): presumably used elsewhere in the package as a fake client
// address for test requests — confirm against callers.
const DefaultRemoteAddr = "1.2.3.4"
57
58
59 func (rw *ResponseRecorder) Header() http.Header {
60 m := rw.HeaderMap
61 if m == nil {
62 m = make(http.Header)
63 rw.HeaderMap = m
64 }
65 return m
66 }
67
68
69
70
71
72
73
74
75 func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
76 if rw.wroteHeader {
77 return
78 }
79 if len(str) > 512 {
80 str = str[:512]
81 }
82
83 m := rw.Header()
84
85 _, hasType := m["Content-Type"]
86 hasTE := m.Get("Transfer-Encoding") != ""
87 if !hasType && !hasTE {
88 if b == nil {
89 b = []byte(str)
90 }
91 m.Set("Content-Type", http.DetectContentType(b))
92 }
93
94 rw.WriteHeader(200)
95 }
96
97
98 func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
99 rw.writeHeader(buf, "")
100 if rw.Body != nil {
101 rw.Body.Write(buf)
102 }
103 return len(buf), nil
104 }
105
106
107 func (rw *ResponseRecorder) WriteString(str string) (int, error) {
108 rw.writeHeader(nil, str)
109 if rw.Body != nil {
110 rw.Body.WriteString(str)
111 }
112 return len(str), nil
113 }
114
115
116
117 func (rw *ResponseRecorder) WriteHeader(code int) {
118 if rw.wroteHeader {
119 return
120 }
121 rw.Code = code
122 rw.wroteHeader = true
123 if rw.HeaderMap == nil {
124 rw.HeaderMap = make(http.Header)
125 }
126 rw.snapHeader = cloneHeader(rw.HeaderMap)
127 }
128
// cloneHeader returns a deep copy of h. Each value slice is copied, so
// mutating either header afterwards does not affect the other. A nil value
// slice becomes an empty, non-nil slice in the copy. A nil h yields an
// empty, non-nil header.
func cloneHeader(h http.Header) http.Header {
	out := make(http.Header, len(h))
	for key, values := range h {
		out[key] = make([]string, len(values))
		copy(out[key], values)
	}
	return out
}
138
139
140 func (rw *ResponseRecorder) Flush() {
141 if !rw.wroteHeader {
142 rw.WriteHeader(200)
143 }
144 rw.Flushed = true
145 }
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162 func (rw *ResponseRecorder) Result() *http.Response {
163 if rw.result != nil {
164 return rw.result
165 }
166 if rw.snapHeader == nil {
167 rw.snapHeader = cloneHeader(rw.HeaderMap)
168 }
169 res := &http.Response{
170 Proto: "HTTP/1.1",
171 ProtoMajor: 1,
172 ProtoMinor: 1,
173 StatusCode: rw.Code,
174 Header: rw.snapHeader,
175 }
176 rw.result = res
177 if res.StatusCode == 0 {
178 res.StatusCode = 200
179 }
180 res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode))
181 if rw.Body != nil {
182 res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
183 }
184 res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
185
186 if trailers, ok := rw.snapHeader["Trailer"]; ok {
187 res.Trailer = make(http.Header, len(trailers))
188 for _, k := range trailers {
189
190
191
192 switch k {
193 case "Transfer-Encoding", "Content-Length", "Trailer":
194
195 continue
196 }
197 k = http.CanonicalHeaderKey(k)
198 vv, ok := rw.HeaderMap[k]
199 if !ok {
200 continue
201 }
202 vv2 := make([]string, len(vv))
203 copy(vv2, vv)
204 res.Trailer[k] = vv2
205 }
206 }
207 for k, vv := range rw.HeaderMap {
208 if !strings.HasPrefix(k, http.TrailerPrefix) {
209 continue
210 }
211 if res.Trailer == nil {
212 res.Trailer = make(http.Header)
213 }
214 for _, v := range vv {
215 res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v)
216 }
217 }
218 return res
219 }
220
221
222
223
224
225
// parseContentLength parses a Content-Length header value. It returns -1
// when the value is empty or not a valid length.
//
// Surrounding whitespace is ignored. Only plain decimal digits are
// accepted: a sign ("+5", "-5") makes the value invalid, as the HTTP
// grammar is Content-Length = 1*DIGIT. Values that do not fit in an int64
// are rejected too.
func parseContentLength(cl string) int64 {
	cl = strings.TrimSpace(cl)
	if cl == "" {
		return -1
	}
	// ParseUint rejects any sign prefix, and bitSize 63 guarantees the
	// result fits in a non-negative int64.
	n, err := strconv.ParseUint(cl, 10, 63)
	if err != nil {
		return -1
	}
	return int64(n)
}
237
View as plain text