diff --git a/jar.go b/jar.go index 5ff2733..fd0a1d8 100644 --- a/jar.go +++ b/jar.go @@ -72,6 +72,13 @@ type Options struct { // (useful for tests). If this is true, the value of Filename will be // ignored. NoPersist bool + + // PersistSessionCookies to disk. By default, session cookies (i.e. those + // without a max-age or expiry time) are only stored in-memory. If true, + // this flag allows session cookies to be loaded from and stored to disk. + // If a file containing session cookies is loaded and this flag is false, + // those cookies will be ignored. + PersistSessionCookies bool } // Jar implements the http.CookieJar interface from the net/http package. @@ -87,6 +94,13 @@ type Jar struct { // entries is a set of entries, keyed by their eTLD+1 and subkeyed by // their name/domain/path. entries map[string]map[string]entry + + // persistSessionCookies to disk. By default, session cookies (i.e. those + // without a max-age or expiry time) are only stored in-memory. If true, + // this flag allows session cookies to be loaded from and stored to disk. + // If a file containing session cookies is loaded and this flag is false, + // those cookies will be ignored. + persistSessionCookies bool } var noOptions Options @@ -108,6 +122,7 @@ func newAtTime(o *Options, now time.Time) (*Jar, error) { if o == nil { o = &noOptions } + jar.persistSessionCookies = o.PersistSessionCookies if jar.psList = o.PublicSuffixList; jar.psList == nil { jar.psList = publicsuffix.List } @@ -369,6 +384,9 @@ func (j *Jar) merge(entries []entry) { if e.CanonicalHost == "" { continue } + if !j.persistSessionCookies && !e.Persistent { + continue + } key := jarKey(e.CanonicalHost, j.psList) id := e.id() submap := j.entries[key] diff --git a/jar_test.go b/jar_test.go index 5ea3467..cfdea6d 100644 --- a/jar_test.go +++ b/jar_test.go @@ -51,11 +51,12 @@ func (emptyPSL) PublicSuffix(d string) string { } // newTestJar creates an empty Jar with testPSL as the public suffix list. -func newTestJar(path string) *Jar { +func newTestJar(path string, persistSessionCookies bool) *Jar { jar, err := New(&Options{ - PublicSuffixList: testPSL{}, - Filename: path, - NoPersist: path == "", + PublicSuffixList: testPSL{}, + Filename: path, + NoPersist: path == "", + PersistSessionCookies: persistSessionCookies, }) if err != nil { panic(err) @@ -306,7 +307,7 @@ var domainAndTypeTests = [...]struct { } func TestDomainAndType(t *testing.T) { - jar := newTestJar("") + jar := newTestJar("", false) for _, tc := range domainAndTypeTests { domain, hostOnly, err := jar.domainAndType(tc.host, tc.domain) if err != tc.wantErr { @@ -518,7 +519,7 @@ var basicsTests = [...]jarTest{ func TestBasics(t *testing.T) { for _, test := range basicsTests { - jar := newTestJar("") + jar := newTestJar("", false) test.run(t, jar) } } @@ -666,14 +667,14 @@ var updateAndDeleteTests = [...]jarTest{ } func TestUpdateAndDelete(t *testing.T) { - jar := newTestJar("") + jar := newTestJar("", false) for _, test := range updateAndDeleteTests { test.run(t, jar) } } func TestExpiration(t *testing.T) { - jar := newTestJar("") + jar := newTestJar("", false) jarTest{ "Expiration.", "http://www.host.test", @@ -901,7 +902,7 @@ var chromiumBasicsTests = [...]jarTest{ func TestChromiumBasics(t *testing.T) { for _, test := range chromiumBasicsTests { - jar := newTestJar("") + jar := newTestJar("", false) test.run(t, jar) } } @@ -961,7 +962,7 @@ var chromiumDomainTests = [...]jarTest{ } func TestChromiumDomain(t *testing.T) { - jar := newTestJar("") + jar := newTestJar("", false) for _, test := range chromiumDomainTests { test.run(t, jar) } @@ -1029,7 +1030,7 @@ var chromiumDeletionTests = [...]jarTest{ } func TestChromiumDeletion(t *testing.T) { - jar := newTestJar("") + jar := newTestJar("", false) for _, test := range chromiumDeletionTests { test.run(t, jar) } @@ -1204,7 +1205,7 @@ var domainHandlingTests = [...]jarTest{ func TestDomainHandling(t *testing.T) { for _, test := range domainHandlingTests { - jar := newTestJar("") + jar := newTestJar("", false) test.run(t, jar) } } @@ -1220,12 +1221,14 @@ func (c mergeCookie) set(jar *Jar) { } var mergeTests = []struct { - description string - setCookies0 []mergeCookie - setCookies1 []mergeCookie - now time.Time - content string - queries []query // Queries to test the Jar.Cookies method + description string + setCookies0 []mergeCookie + setCookies1 []mergeCookie + persistSessionCookies0 bool + persistSessionCookies1 bool + now time.Time + content string + queries []query // Queries to test the Jar.Cookies method }{{ description: "empty jar1", setCookies0: []mergeCookie{ @@ -1347,6 +1350,82 @@ var mergeTests = []struct { {"http://nowhere.com", "A=n"}, {"http://www.elsewhere", "X=x"}, }, +}, { + description: "empty jar1", + setCookies0: []mergeCookie{ + {atTime(0), "http://www.host.test", "A=a; max-age=10"}, + }, + now: atTime(1), + content: "A=a", +}, { + description: "empty jar0", + setCookies1: []mergeCookie{ + {atTime(0), "http://www.host.test", "A=a; max-age=10"}, + }, + now: atTime(1), + content: "A=a", +}, { + description: "simple override (1)", + setCookies0: []mergeCookie{ + {atTime(0), "http://www.host.test", "A=a; max-age=10"}, + }, + setCookies1: []mergeCookie{ + {atTime(1), "http://www.host.test", "A=b; max-age=10"}, + }, + now: atTime(2), + content: "A=b", +}, { + description: "simple override (2)", + setCookies0: []mergeCookie{ + {atTime(1), "http://www.host.test", "A=a; max-age=10"}, + }, + setCookies1: []mergeCookie{ + {atTime(0), "http://www.host.test", "A=b; max-age=10"}, + }, + now: atTime(2), + content: "A=a", +}, { + description: "session cookie persistence, empty jar1", + setCookies0: []mergeCookie{ + {atTime(0), "http://www.host.test", "A=a"}, + }, + persistSessionCookies0: true, + persistSessionCookies1: true, + now: atTime(1), + content: "A=a", +}, { + description: "session cookie persistence, empty jar0", + setCookies1: []mergeCookie{ + {atTime(0), "http://www.host.test", "A=a"}, + }, + persistSessionCookies0: true, + persistSessionCookies1: true, + now: atTime(1), + content: "A=a", +}, { + description: "simple override (1)", + setCookies0: []mergeCookie{ + {atTime(0), "http://www.host.test", "A=a"}, + }, + setCookies1: []mergeCookie{ + {atTime(1), "http://www.host.test", "A=b"}, + }, + persistSessionCookies0: true, + persistSessionCookies1: true, + now: atTime(2), + content: "A=b", +}, { + description: "simple override (2)", + setCookies0: []mergeCookie{ + {atTime(1), "http://www.host.test", "A=a"}, + }, + setCookies1: []mergeCookie{ + {atTime(0), "http://www.host.test", "A=b"}, + }, + persistSessionCookies0: true, + persistSessionCookies1: true, + now: atTime(2), + content: "A=a", }} func TestSaveMerge(t *testing.T) { @@ -1357,11 +1436,11 @@ func TestSaveMerge(t *testing.T) { defer os.RemoveAll(dir) for i, test := range mergeTests { path := filepath.Join(dir, fmt.Sprintf("jar%d", i)) - jar0 := newTestJar(path) + jar0 := newTestJar(path, test.persistSessionCookies0) for _, sc := range test.setCookies0 { sc.set(jar0) } - jar1 := newTestJar(path) + jar1 := newTestJar(path, test.persistSessionCookies1) for _, sc := range test.setCookies1 { sc.set(jar1) } @@ -1398,8 +1477,8 @@ func TestMergeConcurrent(t *testing.T) { defer os.Remove(f.Name()) defer f.Close() - jar0 := newTestJar(f.Name()) - jar1 := newTestJar(f.Name()) + jar0 := newTestJar(f.Name(), false) + jar1 := newTestJar(f.Name(), false) var wg sync.WaitGroup url := mustParseURL("http://foo.com") merger := func(j *Jar) { @@ -1438,7 +1517,7 @@ func TestMergeConcurrent(t *testing.T) { func TestDeleteExpired(t *testing.T) { expirySeconds := int(expiryRemovalDuration / time.Second) - jar := newTestJar("") + jar := newTestJar("", false) now := tNow setCookies(jar, "http://foo.com", []string{ @@ -1499,20 +1578,54 @@ func TestLoadSave(t *testing.T) { c.Assert(err, qt.Equals, nil) defer os.RemoveAll(d) file := filepath.Join(d, "cookies") - j := newTestJar(file) + j := newTestJar(file, false) j.SetCookies(serializeTestURL, serializeTestCookies) err = j.Save() c.Assert(err, qt.Equals, nil) _, err = os.Stat(file) c.Assert(err, qt.Equals, nil) - j1 := newTestJar(file) + j1 := newTestJar(file, false) c.Assert(len(j1.entries), qt.Equals, len(serializeTestCookies)) c.Assert(j1.entries, qt.DeepEquals, j.entries) } +var serializeTestSessionCookies = []*http.Cookie{{ + Name: "foo", + Value: "bar", + Path: "/p", + Domain: "example.com", + Secure: true, + HttpOnly: true, + Raw: "raw string", + Unparsed: []string{"x", "y", "z"}, +}} + +func TestLoadSaveSessionCookies(t *testing.T) { + c := qt.New(t) + d, err := ioutil.TempDir("", "") + c.Assert(err, qt.Equals, nil) + defer os.RemoveAll(d) + file := filepath.Join(d, "cookies") + j := newTestJar(file, true) + j.SetCookies(serializeTestURL, serializeTestSessionCookies) + err = j.Save() + c.Assert(err, qt.Equals, nil) + _, err = os.Stat(file) + c.Assert(err, qt.Equals, nil) + + // load without session cookie persistence enabled + j1 := newTestJar(file, false) + c.Assert(len(j1.entries), qt.Equals, 0) + + // load with session cookie persistence enabled + j2 := newTestJar(file, true) + c.Assert(len(j2.entries), qt.Equals, len(serializeTestCookies)) + c.Assert(j2.entries, qt.DeepEquals, j.entries) +} + func TestMarshalJSON(t *testing.T) { c := qt.New(t) - j := newTestJar("") + j := newTestJar("", false) j.SetCookies(serializeTestURL, serializeTestCookies) // Marshal the cookies. data, err := j.MarshalJSON() @@ -1525,7 +1638,7 @@ func TestMarshalJSON(t *testing.T) { err = ioutil.WriteFile(file, data, 0600) c.Assert(err, qt.Equals, nil) // Load cookies from the file. - j1 := newTestJar(file) + j1 := newTestJar(file, false) c.Assert(len(j1.entries), qt.Equals, len(serializeTestCookies)) c.Assert(j1.entries, qt.DeepEquals, j.entries) } @@ -1539,7 +1652,7 @@ func TestLoadSaveWithNoPersist(t *testing.T) { } defer os.RemoveAll(d) file := filepath.Join(d, "cookies") - j := newTestJar(file) + j := newTestJar(file, false) j.SetCookies(serializeTestURL, serializeTestCookies) if err := j.Save(); err != nil { t.Fatalf("cannot save: %v", err) @@ -2000,7 +2113,7 @@ func TestAllCookies(t *testing.T) { defer os.RemoveAll(dir) for i, test := range allCookiesTests { path := filepath.Join(dir, fmt.Sprintf("jar%d", i)) - jar := newTestJar(path) + jar := newTestJar(path, false) for _, s := range test.set { jar.setCookies(s.url, s.cookies, tNow) } @@ -2017,7 +2130,7 @@ func TestAllCookies(t *testing.T) { } func TestRemoveCookies(t *testing.T) { - jar := newTestJar("") + jar := newTestJar("", false) jar.SetCookies( mustParseURL("https://www.google.com"), []*http.Cookie{ @@ -2068,7 +2181,7 @@ func TestRemoveAllHostIP(t *testing.T) { } func testRemoveAllHost(t *testing.T, setURL *url.URL, removeHost string, shouldRemove bool) { - jar := newTestJar("") + jar := newTestJar("", false) google := mustParseURL("https://www.google.com") jar.SetCookies( google, @@ -2129,7 +2242,7 @@ func testRemoveAllHost(t *testing.T, setURL *url.URL, removeHost string, shouldR } func TestRemoveAll(t *testing.T) { - jar := newTestJar("") + jar := newTestJar("", false) jar.SetCookies( mustParseURL("https://www.google.com"), []*http.Cookie{ diff --git a/serialize.go b/serialize.go index 2792dfb..d8cb788 100644 --- a/serialize.go +++ b/serialize.go @@ -35,7 +35,7 @@ func (j *Jar) MarshalJSON() ([]byte, error) { j.mu.Lock() defer j.mu.Unlock() // Marshaling entries can never fail. - data, _ := json.Marshal(j.allPersistentEntries()) + data, _ := json.Marshal(j.allEntriesToPersist()) return data, nil } @@ -124,20 +124,20 @@ func (j *Jar) mergeFrom(r io.Reader) error { // as a JSON array. func (j *Jar) writeTo(w io.Writer) error { encoder := json.NewEncoder(w) - entries := j.allPersistentEntries() + entries := j.allEntriesToPersist() if err := encoder.Encode(entries); err != nil { return err } return nil } -// allPersistentEntries returns all the entries in the jar, sorted by primarly by canonical host +// allEntriesToPersist returns all the entries in the jar, sorted by primarly by canonical host // name and secondarily by path length. -func (j *Jar) allPersistentEntries() []entry { +func (j *Jar) allEntriesToPersist() []entry { var entries []entry for _, submap := range j.entries { for _, e := range submap { - if e.Persistent { + if j.persistSessionCookies || e.Persistent { entries = append(entries, e) } }