Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,10 @@ func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Reques
return p.sessionStore.Clear(rw, req)
}

func (p *OAuthProxy) ClearAllSessions(req *http.Request, session *sessionsapi.SessionState) error {
return p.sessionStore.ClearAllUserSessions(req, session)
}

// LoadCookiedSession reads the user's authentication details from the request
func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*sessionsapi.SessionState, error) {
return p.sessionStore.Load(req)
Expand Down Expand Up @@ -775,6 +779,17 @@ func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request, signOutA
return
}
err = p.ClearSessionCookie(rw, req)
if signOutAllSessions {
session, errAuthSession := p.getAuthenticatedSession(rw, req)
if errAuthSession != nil {
logger.Errorf("Error clearing all sessions cookie: %v", errAuthSession)
} else {
clearAllError := p.ClearAllSessions(req, session)
if clearAllError != nil {
logger.Errorf("Error clearing session cookie: %v", clearAllError)
}
}
}
if err != nil {
logger.Errorf("Error clearing session cookie: %v", err)
p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error())
Expand Down
1 change: 1 addition & 0 deletions pkg/apis/sessions/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type SessionStore interface {
Save(rw http.ResponseWriter, req *http.Request, s *SessionState) error
Load(req *http.Request) (*SessionState, error)
Clear(rw http.ResponseWriter, req *http.Request) error
ClearAllUserSessions(req *http.Request, session *SessionState) error
VerifyConnection(ctx context.Context) error
}

Expand Down
8 changes: 8 additions & 0 deletions pkg/encryption/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ func GenerateRandomASCIIString(length int) (string, error) {
return string(b), nil
}

// Encrypts a string with a secret using HMAC-SHA256 and returns a base64-encoded string.
func EncryptStringWithSecret(input, secret string) string {
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(input))
sum := mac.Sum(nil)
return base64.RawURLEncoding.EncodeToString(sum)
}

func GenerateCodeChallenge(method, codeVerifier string) (string, error) {
switch method {
case CodeChallengeMethodPlain:
Expand Down
8 changes: 8 additions & 0 deletions pkg/encryption/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ func TestSecretBytesNonBase64(t *testing.T) {
assert.Equal(t, 32, len(sb32))
}

func TestEncryptStringWithSecret(t *testing.T) {
secret := "my-secret"
input := "my-input"

result := EncryptStringWithSecret(input, secret)
assert.Equal(t, result, "7DlH3g1Io9AmyD8tVPEPHdpH9N4jsO07mNCkfNcfW2A")
}

func TestSignAndValidate(t *testing.T) {
seed := "0123456789abcdef"
key := "cookie-name"
Expand Down
14 changes: 11 additions & 3 deletions pkg/middleware/stored_session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -763,9 +763,10 @@ var _ = Describe("Stored Session Suite", func() {
})

type fakeSessionStore struct {
SaveFunc func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error
LoadFunc func(req *http.Request) (*sessionsapi.SessionState, error)
ClearFunc func(rw http.ResponseWriter, req *http.Request) error
SaveFunc func(http.ResponseWriter, *http.Request, *sessionsapi.SessionState) error
LoadFunc func(req *http.Request) (*sessionsapi.SessionState, error)
ClearFunc func(rw http.ResponseWriter, req *http.Request) error
ClearAllUserSessionsFunc func(req *http.Request, session *sessionsapi.SessionState) error
}

func (f *fakeSessionStore) Save(rw http.ResponseWriter, req *http.Request, s *sessionsapi.SessionState) error {
Expand All @@ -788,6 +789,13 @@ func (f *fakeSessionStore) Clear(rw http.ResponseWriter, req *http.Request) erro
return nil
}

func (f *fakeSessionStore) ClearAllUserSessions(req *http.Request, session *sessionsapi.SessionState) error {
if f.ClearAllUserSessionsFunc != nil {
return f.ClearAllUserSessionsFunc(req, session)
}
return nil
}

func (f *fakeSessionStore) VerifyConnection(_ context.Context) error {
return nil
}
5 changes: 5 additions & 0 deletions pkg/sessions/cookie/session_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type SessionStore struct {
Minimal bool
}

// ClearAll implements sessions.SessionStore.
func (s *SessionStore) ClearAllUserSessions(_ *http.Request, _ *sessions.SessionState) error {
return fmt.Errorf("ClearAllUserSessions is only supported by redis store")
}

// Save takes a sessions.SessionState and stores the information from it
// within Cookies set on the HTTP response writer
func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
Expand Down
2 changes: 2 additions & 0 deletions pkg/sessions/persistence/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ import (
type Store interface {
Save(context.Context, string, []byte, time.Duration) error
Load(context.Context, string) ([]byte, error)
LoadList(ctx context.Context, key string) ([]string, error)
Clear(context.Context, string) error
Lock(key string) sessions.Lock
RPush(context.Context, string, string, time.Duration) error
VerifyConnection(context.Context) error
}
36 changes: 33 additions & 3 deletions pkg/sessions/persistence/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
)

// Manager wraps a Store and handles the implementation details of the
Expand Down Expand Up @@ -42,9 +43,14 @@ func (m *Manager) Save(rw http.ResponseWriter, req *http.Request, s *sessions.Se
}
}

err = tckt.saveSession(s, func(key string, val []byte, exp time.Duration) error {
return m.Store.Save(req.Context(), key, val, exp)
})
err = tckt.saveSession(
s,
func(key string, val []byte, exp time.Duration) error {
return m.Store.Save(req.Context(), key, val, exp)
},
func(key string, val string, exp time.Duration) error {
return m.Store.RPush(req.Context(), key, val, exp)
})
if err != nil {
return err
}
Expand All @@ -68,6 +74,30 @@ func (m *Manager) Load(req *http.Request) (*sessions.SessionState, error) {
)
}

// ClearAll implements sessions.SessionStore.
func (m *Manager) ClearAllUserSessions(req *http.Request, session *sessions.SessionState) error {
ticket, _ := decodeTicketFromRequest(req, m.Options)
sessionKey := encryption.EncryptStringWithSecret(session.User+session.Email, ticket.options.Secret)
keys, err := m.Store.LoadList(req.Context(), sessionKey)
if err != nil {
return fmt.Errorf("error decoding ticket to clear session: %v", err)
}

for _, key := range keys {
err = m.Store.Clear(req.Context(), key)
if err != nil {
return fmt.Errorf("error clearing session for key: %v", err)
}
}

err = m.Store.Clear(req.Context(), sessionKey)
if err != nil {
return fmt.Errorf("error clearing sessions keys: %v", err)
}

return err
}

// Clear clears any saved session information for a given ticket cookie.
// Then it clears all session data for that ticket in the Store.
func (m *Manager) Clear(rw http.ResponseWriter, req *http.Request) error {
Expand Down
9 changes: 8 additions & 1 deletion pkg/sessions/persistence/ticket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import (
// a key string, value []byte & (optional) expiration time.Duration
type saveFunc func(string, []byte, time.Duration) error

// saveUserStateFunc performs a persistent store's save functionality using
// a key string, value []byte & (optional) expiration time.Duration
type saverUserMapSessionFunc func(string, string, time.Duration) error

// loadFunc performs a load from a persistent store using a
// string key and returning the stored value as []byte
type loadFunc func(string) ([]byte, error)
Expand Down Expand Up @@ -157,7 +161,7 @@ func decodeTicketFromRequest(req *http.Request, cookieOpts *options.Cookie) (*ti

// saveSession encodes the SessionState with the ticket's secret and persists
// it to disk via the passed saveFunc.
func (t *ticket) saveSession(s *sessions.SessionState, saver saveFunc) error {
func (t *ticket) saveSession(s *sessions.SessionState, saver saveFunc, saverUserMapSession saverUserMapSessionFunc) error {
c, err := t.makeCipher()
if err != nil {
return err
Expand All @@ -166,6 +170,9 @@ func (t *ticket) saveSession(s *sessions.SessionState, saver saveFunc) error {
if err != nil {
return fmt.Errorf("failed to encode the session state with the ticket: %v", err)
}

encodedUserState := encryption.EncryptStringWithSecret(s.User+s.Email, t.options.Secret)
saverUserMapSession(encodedUserState, t.id, 2*time.Hour)
return saver(t.id, ciphertext, t.options.Expire)
}

Expand Down
37 changes: 33 additions & 4 deletions pkg/sessions/persistence/ticket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,26 @@ var _ = Describe("Session Ticket Tests", func() {

ss := &sessions.SessionState{User: "foobar"}
store := map[string][]byte{}
err = t.saveSession(ss, func(k string, v []byte, e time.Duration) error {
store[k] = v
return nil
})
storeUserSessionList := map[string][]string{}
storedUserSessionListExpected := map[string][]string{
"16-axDAZ63SxeHvCLMjoF5EEX0ipSzNNqxpUITxPxgk": {t.id},
}
err = t.saveSession(
ss,
func(k string, v []byte, e time.Duration) error {
store[k] = v
return nil
},
func(key string, value string, d time.Duration) error {
storeUserSessionList[key] = append(storeUserSessionList[key], value)
return nil
})
Expect(err).ToNot(HaveOccurred())

stored, err := sessions.DecodeSessionState(store[t.id], c, false)
Expect(err).ToNot(HaveOccurred())
Expect(stored).To(Equal(ss))
Expect(storeUserSessionList).To(Equal(storedUserSessionListExpected))
})

It("errors when the saveFunc errors", func() {
Expand All @@ -89,9 +100,27 @@ var _ = Describe("Session Ticket Tests", func() {
&sessions.SessionState{User: "foobar"},
func(k string, v []byte, e time.Duration) error {
return errors.New("save error")
},
func(key string, value string, d time.Duration) error {
return nil
})
Expect(err).To(MatchError(errors.New("save error")))
})

It("should not return error when the saverUserMapSession errors", func() {
t, err := newTicket(&options.Cookie{Name: "dummy"})
Expect(err).ToNot(HaveOccurred())

err = t.saveSession(
&sessions.SessionState{User: "foobar"},
func(k string, v []byte, e time.Duration) error {
return nil
},
func(key string, value string, d time.Duration) error {
return errors.New("save user session error")
})
Expect(err).To(BeNil())
})
})

Context("loadSession", func() {
Expand Down
29 changes: 29 additions & 0 deletions pkg/sessions/redis/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ import (
type Client interface {
Get(ctx context.Context, key string) ([]byte, error)
Lock(key string) sessions.Lock
Expire(ctx context.Context, key string, expiration time.Duration) error
RPush(ctx context.Context, key string, value string) error
LRange(ctx context.Context, key string) ([]string, error)
Set(ctx context.Context, key string, value []byte, expiration time.Duration) error
Del(ctx context.Context, key string) error
Ping(ctx context.Context) error
Expand All @@ -29,6 +32,18 @@ func newClient(c *redis.Client) Client {
}
}

func (c *client) Expire(ctx context.Context, key string, expiration time.Duration) error {
return c.Client.Expire(ctx, key, expiration).Err()
}

func (c *client) LRange(ctx context.Context, key string) ([]string, error) {
return c.Client.LRange(ctx, key, 0, -1).Result()
}

func (c *client) RPush(ctx context.Context, key string, value string) error {
return c.Client.RPush(ctx, key, value).Err()
}

func (c *client) Get(ctx context.Context, key string) ([]byte, error) {
return c.Client.Get(ctx, key).Bytes()
}
Expand Down Expand Up @@ -61,6 +76,12 @@ func newClusterClient(c *redis.ClusterClient) Client {
}
}

// Expire implements Client.
// Subtle: this method shadows the method (*ClusterClient).Expire of clusterClient.ClusterClient.
func (c *clusterClient) Expire(ctx context.Context, key string, expiration time.Duration) error {
return c.ClusterClient.Expire(ctx, key, expiration).Err()
}

func (c *clusterClient) Get(ctx context.Context, key string) ([]byte, error) {
return c.ClusterClient.Get(ctx, key).Bytes()
}
Expand All @@ -69,6 +90,14 @@ func (c *clusterClient) Set(ctx context.Context, key string, value []byte, expir
return c.ClusterClient.Set(ctx, key, value, expiration).Err()
}

func (c *clusterClient) RPush(ctx context.Context, key string, value string) error {
return c.ClusterClient.RPush(ctx, key, value).Err()
}

func (c *clusterClient) LRange(ctx context.Context, key string) ([]string, error) {
return c.ClusterClient.LRange(ctx, key, 0, -1).Result()
}

func (c *clusterClient) Del(ctx context.Context, key string) error {
return c.ClusterClient.Del(ctx, key).Err()
}
Expand Down
25 changes: 25 additions & 0 deletions pkg/sessions/redis/redis_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,31 @@ func (store *SessionStore) Save(ctx context.Context, key string, value []byte, e
return nil
}

// Save takes a sessions.SessionState and stores the information from it
// to redis, and adds a new persistence cookie on the HTTP response writer
func (store *SessionStore) RPush(ctx context.Context, key string, value string, exp time.Duration) error {
err := store.Client.RPush(ctx, key, value)
if err != nil {
return fmt.Errorf("error appending redis session: %v", err)
}

if exp > 0 {
if err := store.Client.Expire(ctx, key, exp); err != nil {
return fmt.Errorf("error settings expiration time on appending redis session: %v", err)
}
}
return nil
}

// LoadList reads a list of strings from Redis at the given key and returns.
func (store *SessionStore) LoadList(ctx context.Context, key string) ([]string, error) {
values, err := store.Client.LRange(ctx, key)
if err != nil {
return nil, fmt.Errorf("error loading redis list: %v", err)
}
return values, nil
}

// Load reads sessions.SessionState information from a persistence
// cookie within the HTTP request object
func (store *SessionStore) Load(ctx context.Context, key string) ([]byte, error) {
Expand Down
Loading
Loading