From 03eec9d2b8b601b29b5f76d5bf936829182a2596 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Fri, 1 Oct 2021 09:35:21 +0200 Subject: [PATCH] refactor: robustify logout routes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Morten Lied Johansen Co-authored-by: Sindre Rødseth Hansen --- pkg/router/router.go | 59 +++++++++++++++---------------------------- pkg/router/session.go | 27 ++++++++++++++++++-- pkg/session/redis.go | 15 ++++------- 3 files changed, 51 insertions(+), 50 deletions(-) diff --git a/pkg/router/router.go b/pkg/router/router.go index 1d8a463..aa03b25 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -326,20 +326,6 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { // Logout triggers self-initiated for the current user func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { - var idToken string - - sess, err := h.getSessionFromCookie(r) - if err == nil && sess != nil && sess.OAuth2Token != nil { - idToken = sess.IDTokenSerialized - err = h.Sessions.Delete(r.Context(), h.localSessionID(sess.ExternalSessionID)) - if err != nil { - log.Error(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - h.deleteCookie(w, h.GetSessionCookieName()) - } - u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint) if err != nil { log.Error(err) @@ -347,6 +333,21 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { return } + var idToken string + + sess, err := h.getSessionFromCookie(r) + if err == nil && sess != nil && sess.OAuth2Token != nil { + idToken = sess.IDTokenSerialized + err = h.destroySession(r, h.localSessionID(sess.ExternalSessionID)) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + } + + h.deleteCookie(w, h.GetSessionCookieName()) + v := u.Query() v.Add("post_logout_redirect_uri", PostLogoutRedirectURI(r, h.Config.PostLogoutRedirectURI)) @@ -362,42 +363,24 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { params := r.URL.Query() - iss := params.Get("iss") sid := params.Get("sid") - if len(sid) == 0 || len(iss) == 0 { + if len(sid) == 0 { + log.Error("sid not set for front-channel logout") w.WriteHeader(http.StatusBadRequest) return } sessionID := h.localSessionID(sid) - // From here on, check that 'iss' from request matches data found in access token. - accessToken, err := h.getAccessTokenFromSession(r, sessionID) - if accessToken == nil { - // Can't remove session because it doesn't exist. Maybe it was garbage collected. - // We regard this as a redundant logout and return 200 OK. - return - } + err := h.destroySession(r, sessionID) if err != nil { log.Error(err) - w.WriteHeader(http.StatusInternalServerError) - return + // Session is already destroyed at the OP and is highly unlikely to be used again. } - err = jwt.Validate(accessToken, jwt.WithClaimValue("iss", iss)) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - // All verified; delete session. - err = h.Sessions.Delete(r.Context(), sessionID) - if err != nil { - log.Error(err) - w.WriteHeader(http.StatusInternalServerError) - return - } + // Unconditionally destroy all local references to the session. + h.deleteCookie(w, h.GetSessionCookieName()) } func New(handler *Handler, prefixes []string) chi.Router { diff --git a/pkg/router/session.go b/pkg/router/session.go index 6a5769d..2553449 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -1,7 +1,9 @@ package router import ( + "errors" "fmt" + "github.com/go-redis/redis/v8" "github.com/lestrrat-go/jwx/jwt" "github.com/nais/wonderwall/pkg/session" "github.com/nais/wonderwall/pkg/token" @@ -27,6 +29,10 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) { encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID) if err != nil { + if errors.Is(err, redis.Nil) { + // TODO: attempt to fetch encrypted data from fallback session cookie (if set) + } + return nil, fmt.Errorf("reading session from store: %w", err) } @@ -81,6 +87,8 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime) if err != nil { + // TODO: fallback to writing encrypted session data to cookie + return fmt.Errorf("writing session to store: %w", err) } @@ -90,8 +98,13 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) (jwt.Token, error) { encryptedSession, err := h.Sessions.Read(r.Context(), sessionID) if err != nil { - // Session not found; ignoring - return nil, nil + if errors.Is(err, redis.Nil) { + // Session not found; ignoring + return nil, nil + } + + // TODO: fetch from fallback session cookie (if set) + return nil, fmt.Errorf("fetching session from store: %w", err) } sessionData, err := encryptedSession.Decrypt(h.Crypter) @@ -107,3 +120,13 @@ func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) ( return accessToken, nil } + +func (h *Handler) destroySession(r *http.Request, sessionID string) error { + err := h.Sessions.Delete(r.Context(), sessionID) + if err != nil { + return fmt.Errorf("deleting session from store: %w", err) + } + + // TODO: delete fallback session cookie (if set) + return nil +} diff --git a/pkg/session/redis.go b/pkg/session/redis.go index ac672ef..c6ae959 100644 --- a/pkg/session/redis.go +++ b/pkg/session/redis.go @@ -8,24 +8,21 @@ import ( ) type redisSessionStore struct { - client redis.Cmdable + client redis.Cmdable } var _ Store = &redisSessionStore{} func NewRedis(client redis.Cmdable) Store { return &redisSessionStore{ - client: client, + client: client, } } func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedData, error) { encryptedData := &EncryptedData{} err := metrics.ObserveRedisLatency("Read", func() error { - var err error - status := s.client.Get(ctx, key) - err = status.Scan(encryptedData) - return err + return s.client.Get(ctx, key).Scan(encryptedData) }) if err != nil { return nil, err @@ -36,14 +33,12 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedDat func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error { return metrics.ObserveRedisLatency("Write", func() error { - status := s.client.Set(ctx, key, value, expiration) - return status.Err() + return s.client.Set(ctx, key, value, expiration).Err() }) } func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error { return metrics.ObserveRedisLatency("Delete", func() error { - status := s.client.Del(ctx, keys...) - return status.Err() + return s.client.Del(ctx, keys...).Err() }) }