1
2
3
4
5
6 package cookiejar
7
8 import (
9 "errors"
10 "fmt"
11 "net"
12 "net/http"
13 "net/url"
14 "sort"
15 "strings"
16 "sync"
17 "time"
18 )
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
// PublicSuffixList provides the public suffix of a domain; for example the
// public suffix of "example.com" is "com".
//
// The jar calls PublicSuffix (via jarKey) from Cookies and SetCookies before
// it takes its own lock. An implementation used by a jar that is shared
// between goroutines must therefore be safe for concurrent use.
//
// If PublicSuffix returns "", the public-suffix check performed when a cookie
// carries a Domain attribute is skipped.
type PublicSuffixList interface {
	// String describes the list. It is not called by the code visible in
	// this part of the package; presumably it identifies the list's source
	// or version for debugging.
	String() string

	// PublicSuffix returns the public suffix of domain.
	//
	// NOTE(review): the contract for IP addresses, leading/trailing dots,
	// letter case and IDN/Punycode input is not specified here; the jar
	// passes canonical (lower-cased, ASCII) host names.
	PublicSuffix(domain string) string
}
47
48
// Options are the options for creating a new Jar.
type Options struct {
	// PublicSuffixList is consulted to decide whether a server may set a
	// cookie for a given Domain attribute and to compute the key under
	// which cookies are grouped inside the jar.
	//
	// A nil value is valid: the public-suffix check on Domain attributes is
	// skipped, and cookies are grouped by the last two dot-separated labels
	// of the host.
	PublicSuffixList PublicSuffixList
}
58
59
// Jar is an in-memory cookie jar; its Cookies and SetCookies methods have
// the signatures of the http.CookieJar interface. Create one with New.
type Jar struct {
	// psList is the public suffix list taken from Options; it may be nil.
	// It is set once by New and read without holding mu.
	psList PublicSuffixList

	// mu guards entries and nextSeqNum.
	mu sync.Mutex

	// entries holds the stored cookies. The outer key is jarKey(host)
	// (roughly the registrable domain, or the host itself for IP
	// addresses); the inner key is entry.id(), i.e. "domain;path;name".
	entries map[string]map[string]entry

	// nextSeqNum is the sequence number given to the next newly inserted
	// (not replaced) cookie; it breaks ties when ordering cookies.
	nextSeqNum uint64
}
74
75
76
77 func New(o *Options) (*Jar, error) {
78 jar := &Jar{
79 entries: make(map[string]map[string]entry),
80 }
81 if o != nil {
82 jar.psList = o.PublicSuffixList
83 }
84 return jar, nil
85 }
86
87
88
89
90
// entry is the internal representation of a stored cookie.
//
// Its exported fields mirror the cookie storage model of RFC 6265
// section 5.3; the type itself is not used outside this package.
type entry struct {
	Name  string
	Value string
	// Domain is the canonical request host for host-only cookies, otherwise
	// the lower-cased Domain attribute without its optional leading dot.
	Domain string
	// Path is the cookie's Path attribute, or the request's default path
	// when the attribute is absent or does not start with '/'.
	Path     string
	Secure   bool
	HttpOnly bool
	// Persistent is true when the cookie had a positive MaxAge or a
	// non-zero Expires; only persistent cookies are ever expired.
	Persistent bool
	// HostOnly cookies are sent only to exactly Domain, not its subdomains.
	HostOnly bool
	// Expires is endOfTime for session (non-persistent) cookies.
	Expires    time.Time
	Creation   time.Time
	LastAccess time.Time

	// seqNum is a sequence number so that Cookies returns cookies in a
	// deterministic order, even for cookies with equal Path length and
	// equal Creation time.
	seqNum uint64
}
109
110
111 func (e *entry) id() string {
112 return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
113 }
114
115
116
117
118 func (e *entry) shouldSend(https bool, host, path string) bool {
119 return e.domainMatch(host) && e.pathMatch(path) && (https || !e.Secure)
120 }
121
122
123 func (e *entry) domainMatch(host string) bool {
124 if e.Domain == host {
125 return true
126 }
127 return !e.HostOnly && hasDotSuffix(host, e.Domain)
128 }
129
130
131 func (e *entry) pathMatch(requestPath string) bool {
132 if requestPath == e.Path {
133 return true
134 }
135 if strings.HasPrefix(requestPath, e.Path) {
136 if e.Path[len(e.Path)-1] == '/' {
137 return true
138 } else if requestPath[len(e.Path)] == '/' {
139 return true
140 }
141 }
142 return false
143 }
144
145
// hasDotSuffix reports whether s ends in "."+suffix, i.e. whether s is a
// strict dot-separated subdomain of suffix.
func hasDotSuffix(s, suffix string) bool {
	if len(s) <= len(suffix) || !strings.HasSuffix(s, suffix) {
		return false
	}
	return s[len(s)-len(suffix)-1] == '.'
}
149
150
151
152
153 func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
154 return j.cookies(u, time.Now())
155 }
156
157
158 func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
159 if u.Scheme != "http" && u.Scheme != "https" {
160 return cookies
161 }
162 host, err := canonicalHost(u.Host)
163 if err != nil {
164 return cookies
165 }
166 key := jarKey(host, j.psList)
167
168 j.mu.Lock()
169 defer j.mu.Unlock()
170
171 submap := j.entries[key]
172 if submap == nil {
173 return cookies
174 }
175
176 https := u.Scheme == "https"
177 path := u.Path
178 if path == "" {
179 path = "/"
180 }
181
182 modified := false
183 var selected []entry
184 for id, e := range submap {
185 if e.Persistent && !e.Expires.After(now) {
186 delete(submap, id)
187 modified = true
188 continue
189 }
190 if !e.shouldSend(https, host, path) {
191 continue
192 }
193 e.LastAccess = now
194 submap[id] = e
195 selected = append(selected, e)
196 modified = true
197 }
198 if modified {
199 if len(submap) == 0 {
200 delete(j.entries, key)
201 } else {
202 j.entries[key] = submap
203 }
204 }
205
206
207
208 sort.Slice(selected, func(i, j int) bool {
209 s := selected
210 if len(s[i].Path) != len(s[j].Path) {
211 return len(s[i].Path) > len(s[j].Path)
212 }
213 if !s[i].Creation.Equal(s[j].Creation) {
214 return s[i].Creation.Before(s[j].Creation)
215 }
216 return s[i].seqNum < s[j].seqNum
217 })
218 for _, e := range selected {
219 cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value})
220 }
221
222 return cookies
223 }
224
225
226
227
228 func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
229 j.setCookies(u, cookies, time.Now())
230 }
231
232
233 func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
234 if len(cookies) == 0 {
235 return
236 }
237 if u.Scheme != "http" && u.Scheme != "https" {
238 return
239 }
240 host, err := canonicalHost(u.Host)
241 if err != nil {
242 return
243 }
244 key := jarKey(host, j.psList)
245 defPath := defaultPath(u.Path)
246
247 j.mu.Lock()
248 defer j.mu.Unlock()
249
250 submap := j.entries[key]
251
252 modified := false
253 for _, cookie := range cookies {
254 e, remove, err := j.newEntry(cookie, now, defPath, host)
255 if err != nil {
256 continue
257 }
258 id := e.id()
259 if remove {
260 if submap != nil {
261 if _, ok := submap[id]; ok {
262 delete(submap, id)
263 modified = true
264 }
265 }
266 continue
267 }
268 if submap == nil {
269 submap = make(map[string]entry)
270 }
271
272 if old, ok := submap[id]; ok {
273 e.Creation = old.Creation
274 e.seqNum = old.seqNum
275 } else {
276 e.Creation = now
277 e.seqNum = j.nextSeqNum
278 j.nextSeqNum++
279 }
280 e.LastAccess = now
281 submap[id] = e
282 modified = true
283 }
284
285 if modified {
286 if len(submap) == 0 {
287 delete(j.entries, key)
288 } else {
289 j.entries[key] = submap
290 }
291 }
292 }
293
294
295
296 func canonicalHost(host string) (string, error) {
297 var err error
298 host = strings.ToLower(host)
299 if hasPort(host) {
300 host, _, err = net.SplitHostPort(host)
301 if err != nil {
302 return "", err
303 }
304 }
305 if strings.HasSuffix(host, ".") {
306
307 host = host[:len(host)-1]
308 }
309 return toASCII(host)
310 }
311
312
313
// hasPort reports whether host contains a port number. host may be a host
// name, an IPv4 address or a bracketed or bare IPv6 address.
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		return true
	default:
		// Several colons: only a bracketed IPv6 literal followed by
		// ":port" carries a port; a bare IPv6 address does not.
		return host[0] == '[' && strings.Contains(host, "]:")
	}
}
324
325
326 func jarKey(host string, psl PublicSuffixList) string {
327 if isIP(host) {
328 return host
329 }
330
331 var i int
332 if psl == nil {
333 i = strings.LastIndex(host, ".")
334 if i <= 0 {
335 return host
336 }
337 } else {
338 suffix := psl.PublicSuffix(host)
339 if suffix == host {
340 return host
341 }
342 i = len(host) - len(suffix)
343 if i <= 0 || host[i-1] != '.' {
344
345
346 return host
347 }
348
349
350
351 }
352 prevDot := strings.LastIndex(host[:i-1], ".")
353 return host[prevDot+1:]
354 }
355
356
// isIP reports whether host is an IP address.
//
// Any host containing ':' or '%' is treated as an IP address even when
// net.ParseIP rejects it. Those characters cannot occur in a valid host
// name, but they do occur in IPv6 literals with zone identifiers
// (e.g. "fe80::1%eth0"), which ParseIP does not accept. Without this check a
// string such as "::1%.www.example.com" would be handled as a host name and
// could domain-match, and receive cookies for, www.example.com. Treating it
// as an IP is the conservative choice.
func isIP(host string) bool {
	if strings.ContainsAny(host, ":%") {
		return true
	}
	return net.ParseIP(host) != nil
}
360
361
362
// defaultPath returns the directory part of a URL's path according to
// RFC 6265 section 5.1.4: everything before the last '/', or "/" when the
// path is empty, relative, or has only a single leading '/'.
func defaultPath(path string) string {
	if !strings.HasPrefix(path, "/") {
		return "/" // empty or malformed (relative) path
	}
	// path starts with '/', so LastIndex is never -1.
	dir := path[:strings.LastIndex(path, "/")]
	if dir == "" {
		return "/" // path has the form "/abc"
	}
	return dir
}
374
375
376
377
378
379
380
381
382
383
384 func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
385 e.Name = c.Name
386
387 if c.Path == "" || c.Path[0] != '/' {
388 e.Path = defPath
389 } else {
390 e.Path = c.Path
391 }
392
393 e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
394 if err != nil {
395 return e, false, err
396 }
397
398
399 if c.MaxAge < 0 {
400 return e, true, nil
401 } else if c.MaxAge > 0 {
402 e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
403 e.Persistent = true
404 } else {
405 if c.Expires.IsZero() {
406 e.Expires = endOfTime
407 e.Persistent = false
408 } else {
409 if !c.Expires.After(now) {
410 return e, true, nil
411 }
412 e.Expires = c.Expires
413 e.Persistent = true
414 }
415 }
416
417 e.Value = c.Value
418 e.Secure = c.Secure
419 e.HttpOnly = c.HttpOnly
420
421 return e, false, nil
422 }
423
// Errors returned by domainAndType (and propagated by newEntry) when a
// cookie's Domain attribute cannot be accepted.
var (
	// errMalformedDomain: the Domain attribute is syntactically invalid,
	// e.g. "." or a value with a leading ".." or a trailing dot.
	errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
	// errIllegalDomain: the Domain attribute is well-formed but not
	// allowed for the request host.
	errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
	// errNoHostname: a Domain attribute was supplied but the request host
	// is an IP address.
	errNoHostname = errors.New("cookiejar: no host name available (IP only)")
)
429
430
431
432
// endOfTime is the expiry time stored for session cookies: the last second
// of year 9999, UTC. It is far enough in the future to never be reached.
var endOfTime = time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)
434
435
436 func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
437 if domain == "" {
438
439
440 return host, true, nil
441 }
442
443 if isIP(host) {
444
445
446
447 return "", false, errNoHostname
448 }
449
450
451
452
453 if domain[0] == '.' {
454 domain = domain[1:]
455 }
456
457 if len(domain) == 0 || domain[0] == '.' {
458
459
460 return "", false, errMalformedDomain
461 }
462 domain = strings.ToLower(domain)
463
464 if domain[len(domain)-1] == '.' {
465
466
467
468
469
470
471 return "", false, errMalformedDomain
472 }
473
474
475 if j.psList != nil {
476 if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
477 if host == domain {
478
479
480 return host, true, nil
481 }
482 return "", false, errIllegalDomain
483 }
484 }
485
486
487
488 if host != domain && !hasDotSuffix(host, domain) {
489 return "", false, errIllegalDomain
490 }
491
492 return domain, false, nil
493 }
494
View as plain text