1
2
3
4
5
6
7 package httputil
8
9 import (
10 "context"
11 "io"
12 "log"
13 "net"
14 "net/http"
15 "net/url"
16 "strings"
17 "sync"
18 "time"
19 )
20
21
22
// onExitFlushLoop, when non-nil, is invoked by maxLatencyWriter.flushLoop
// just before that loop returns after being told to stop.
// NOTE(review): nothing in this file assigns it; presumably a hook set by
// tests — confirm.
var onExitFlushLoop func()
24
25
26
27
// ReverseProxy is an HTTP handler (see ServeHTTP) that takes an incoming
// request, sends it to another server via Transport, and relays the
// response back to the client.
type ReverseProxy struct {
	// Director is called by ServeHTTP with the outgoing request (a copy of
	// the incoming request with its own cloned header map) and may modify
	// it before it is sent. ServeHTTP calls it unconditionally, so it must
	// not be nil.
	Director func(*http.Request)

	// Transport performs the outgoing request.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval is the period at which the client connection is
	// flushed while the response body is being copied. Periodic flushing
	// only happens when this is non-zero and the destination writer also
	// implements http.Flusher.
	FlushInterval time.Duration

	// ErrorLog is the logger used for proxy errors.
	// If nil, messages go through the standard log package's Printf.
	ErrorLog *log.Logger

	// BufferPool optionally supplies the byte slice used while copying
	// the response body. If nil, a buffer is allocated per copy.
	BufferPool BufferPool

	// ModifyResponse, if non-nil, is called with the backend response
	// after hop-by-hop headers have been removed and before anything is
	// written to the client. If it returns an error, the error is logged,
	// the client receives a 502 Bad Gateway status, and the backend
	// response body is closed.
	ModifyResponse func(*http.Response) error
}
63
64
65
// BufferPool is the interface ReverseProxy uses to obtain and give back the
// byte slice used while copying a response body (see copyResponse).
type BufferPool interface {
	// Get returns a buffer to copy with. A zero-length result is replaced
	// by a freshly allocated 32 KiB buffer inside copyBuffer.
	Get() []byte
	// Put hands back the slice previously returned by Get once the copy
	// has finished.
	Put([]byte)
}
70
// singleJoiningSlash concatenates a and b so that exactly one "/" sits at
// the seam between them: a duplicate slash is collapsed and a missing one
// is inserted.
func singleJoiningSlash(a, b string) string {
	aEndsWithSlash := strings.HasSuffix(a, "/")
	bStartsWithSlash := strings.HasPrefix(b, "/")

	if aEndsWithSlash && bStartsWithSlash {
		// Both sides contribute a slash; drop the one from b.
		return a + b[1:]
	}
	if aEndsWithSlash || bStartsWithSlash {
		// Exactly one side contributes the slash already.
		return a + b
	}
	// Neither side has a slash at the seam; supply it.
	return a + "/" + b
}
82
83
84
85
86
87
88
89
90 func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
91 targetQuery := target.RawQuery
92 director := func(req *http.Request) {
93 req.URL.Scheme = target.Scheme
94 req.URL.Host = target.Host
95 req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
96 if targetQuery == "" || req.URL.RawQuery == "" {
97 req.URL.RawQuery = targetQuery + req.URL.RawQuery
98 } else {
99 req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
100 }
101 if _, ok := req.Header["User-Agent"]; !ok {
102
103 req.Header.Set("User-Agent", "")
104 }
105 }
106 return &ReverseProxy{Director: director}
107 }
108
// copyHeader appends every value of every header in src to dst, using
// dst.Add so existing values in dst are kept and keys are canonicalized.
func copyHeader(dst, src http.Header) {
	for key, values := range src {
		for i := range values {
			dst.Add(key, values[i])
		}
	}
}
116
// cloneHeader returns a deep copy of h: a new map whose value slices are
// freshly allocated, so mutating the copy never affects h. A nil h yields
// an empty, non-nil header.
func cloneHeader(h http.Header) http.Header {
	dup := make(http.Header, len(h))
	for key, values := range h {
		// Copy into a new backing array of exactly the same length.
		dup[key] = append(make([]string, 0, len(values)), values...)
	}
	return dup
}
126
127
128
// hopHeaders lists the hop-by-hop header names that ServeHTTP deletes from
// both the outgoing request and the backend response, so they are not
// forwarded across the proxy. Names are in canonical header form.
var hopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}
140
// ServeHTTP forwards req to the backend chosen by p.Director using
// p.Transport (or http.DefaultTransport when nil) and relays the backend's
// status, headers, body and trailers to rw.
//
// If the round trip fails, or p.ModifyResponse returns an error, the error
// is logged via p.logf and the client receives 502 Bad Gateway.
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	transport := p.Transport
	if transport == nil {
		transport = http.DefaultTransport
	}

	ctx := req.Context()
	// When rw can signal that the client went away, derive a cancellable
	// context and cancel it on that signal. The watcher goroutine ends on
	// either the close notification or ctx being done; the deferred cancel
	// guarantees the latter once ServeHTTP returns.
	if cn, ok := rw.(http.CloseNotifier); ok {
		var cancel context.CancelFunc
		ctx, cancel = context.WithCancel(ctx)
		defer cancel()
		notifyChan := cn.CloseNotify()
		go func() {
			select {
			case <-notifyChan:
				cancel()
			case <-ctx.Done():
			}
		}()
	}

	// outreq is a copy of req bound to ctx; it is what gets sent upstream.
	outreq := req.WithContext(ctx)
	if req.ContentLength == 0 {
		// A request declaring no content is sent without a body.
		outreq.Body = nil
	}

	// Give outreq its own header map so the edits below (and Director's)
	// do not touch req.Header.
	outreq.Header = cloneHeader(req.Header)

	p.Director(outreq)
	// Set after Director runs, so this overrides whatever Director chose.
	outreq.Close = false

	// Drop headers named as connection options before the Connection
	// header itself is removed by the hop-by-hop loop below.
	removeConnectionHeaders(outreq.Header)

	// Remove hop-by-hop headers from the outgoing request. Only headers
	// whose first value is non-empty are deleted here.
	for _, h := range hopHeaders {
		if outreq.Header.Get(h) != "" {
			outreq.Header.Del(h)
		}
	}

	// Record the client's IP in X-Forwarded-For. RemoteAddr values that do
	// not split into host and port are skipped silently.
	if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
		// Fold any existing X-Forwarded-For values into one
		// comma+space separated list and append this client's IP.
		if prior, ok := outreq.Header["X-Forwarded-For"]; ok {
			clientIP = strings.Join(prior, ", ") + ", " + clientIP
		}
		outreq.Header.Set("X-Forwarded-For", clientIP)
	}

	res, err := transport.RoundTrip(outreq)
	if err != nil {
		p.logf("http: proxy error: %v", err)
		rw.WriteHeader(http.StatusBadGateway)
		return
	}

	// Same scrubbing on the response side: connection options first, then
	// the hop-by-hop headers (deleted unconditionally here).
	removeConnectionHeaders(res.Header)

	for _, h := range hopHeaders {
		res.Header.Del(h)
	}

	if p.ModifyResponse != nil {
		if err := p.ModifyResponse(res); err != nil {
			p.logf("http: proxy error: %v", err)
			rw.WriteHeader(http.StatusBadGateway)
			res.Body.Close()
			return
		}
	}

	copyHeader(rw.Header(), res.Header)

	// Trailer keys already known before the body is read are announced to
	// the client in a "Trailer" header. The count is remembered so that
	// trailers appearing only after the body copy can be detected below.
	announcedTrailers := len(res.Trailer)
	if announcedTrailers > 0 {
		trailerKeys := make([]string, 0, len(res.Trailer))
		for k := range res.Trailer {
			trailerKeys = append(trailerKeys, k)
		}
		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
	}

	rw.WriteHeader(res.StatusCode)
	if len(res.Trailer) > 0 {
		// Flush right after the status line and headers when trailers were
		// announced and rw supports flushing.
		// NOTE(review): presumably this forces a chunked response so the
		// trailers can be sent even for an empty body — confirm.
		if fl, ok := rw.(http.Flusher); ok {
			fl.Flush()
		}
	}
	p.copyResponse(rw, res.Body)
	res.Body.Close()

	// Trailer count unchanged since the announcement (including the case
	// of no trailers at all): copy them under their plain names and stop.
	if len(res.Trailer) == announcedTrailers {
		copyHeader(rw.Header(), res.Trailer)
		return
	}

	// Otherwise the trailer set changed while the body was read; emit
	// every trailer under the http.TrailerPrefix-qualified key.
	for k, vv := range res.Trailer {
		k = http.TrailerPrefix + k
		for _, v := range vv {
			rw.Header().Add(k, v)
		}
	}
}
252
253
254
// removeConnectionHeaders deletes from h every header that is named as a
// connection option in any of h's "Connection" header fields (RFC 7230,
// section 6.1). The "Connection" header itself is left in place unless it
// lists itself.
//
// All Connection field lines are examined, not just the first: a message
// may carry several Connection headers, and options named in later ones
// must not be forwarded either.
func removeConnectionHeaders(h http.Header) {
	for _, field := range h["Connection"] {
		for _, option := range strings.Split(field, ",") {
			if option = strings.TrimSpace(option); option != "" {
				h.Del(option)
			}
		}
	}
}
264
265 func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
266 if p.FlushInterval != 0 {
267 if wf, ok := dst.(writeFlusher); ok {
268 mlw := &maxLatencyWriter{
269 dst: wf,
270 latency: p.FlushInterval,
271 done: make(chan bool),
272 }
273 go mlw.flushLoop()
274 defer mlw.stop()
275 dst = mlw
276 }
277 }
278
279 var buf []byte
280 if p.BufferPool != nil {
281 buf = p.BufferPool.Get()
282 }
283 p.copyBuffer(dst, src, buf)
284 if p.BufferPool != nil {
285 p.BufferPool.Put(buf)
286 }
287 }
288
289 func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
290 if len(buf) == 0 {
291 buf = make([]byte, 32*1024)
292 }
293 var written int64
294 for {
295 nr, rerr := src.Read(buf)
296 if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
297 p.logf("httputil: ReverseProxy read error during body copy: %v", rerr)
298 }
299 if nr > 0 {
300 nw, werr := dst.Write(buf[:nr])
301 if nw > 0 {
302 written += int64(nw)
303 }
304 if werr != nil {
305 return written, werr
306 }
307 if nr != nw {
308 return written, io.ErrShortWrite
309 }
310 }
311 if rerr != nil {
312 return written, rerr
313 }
314 }
315 }
316
317 func (p *ReverseProxy) logf(format string, args ...interface{}) {
318 if p.ErrorLog != nil {
319 p.ErrorLog.Printf(format, args...)
320 } else {
321 log.Printf(format, args...)
322 }
323 }
324
// writeFlusher is a writer that can also be flushed. copyResponse checks
// for it to decide whether periodic flushing is possible.
type writeFlusher interface {
	io.Writer
	http.Flusher
}
329
// maxLatencyWriter wraps a writeFlusher; its flushLoop method flushes dst
// every latency interval until stop is called.
type maxLatencyWriter struct {
	dst     writeFlusher  // underlying writer that gets written to and flushed
	latency time.Duration // period between flushes performed by flushLoop

	mu   sync.Mutex // held around each dst.Write and dst.Flush call
	done chan bool  // stop sends on it; flushLoop returns on receipt
}
337
338 func (m *maxLatencyWriter) Write(p []byte) (int, error) {
339 m.mu.Lock()
340 defer m.mu.Unlock()
341 return m.dst.Write(p)
342 }
343
344 func (m *maxLatencyWriter) flushLoop() {
345 t := time.NewTicker(m.latency)
346 defer t.Stop()
347 for {
348 select {
349 case <-m.done:
350 if onExitFlushLoop != nil {
351 onExitFlushLoop()
352 }
353 return
354 case <-t.C:
355 m.mu.Lock()
356 m.dst.Flush()
357 m.mu.Unlock()
358 }
359 }
360 }
361
362 func (m *maxLatencyWriter) stop() { m.done <- true }
363
View as plain text