From 6a142cf5a5299ed2ff7b7f5a64690b17c3557fa2 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Wed, 4 Jan 2023 11:00:33 +0100 Subject: [PATCH] refactor(handler): use session cookie for frontchannel logout if available, clean up logout handlers --- pkg/handler/api/logout/logout.go | 33 ++++++++++--------- .../logoutfrontchannel/logoutfrontchannel.go | 29 ++++++++++++---- pkg/session/handler.go | 15 ++------- 3 files changed, 42 insertions(+), 35 deletions(-) diff --git a/pkg/handler/api/logout/logout.go b/pkg/handler/api/logout/logout.go index 6cb0944..5752e09 100644 --- a/pkg/handler/api/logout/logout.go +++ b/pkg/handler/api/logout/logout.go @@ -5,8 +5,6 @@ import ( "fmt" "net/http" - log "github.com/sirupsen/logrus" - "github.com/nais/wonderwall/pkg/cookie" errorhandler "github.com/nais/wonderwall/pkg/handler/error" "github.com/nais/wonderwall/pkg/loginstatus" @@ -37,23 +35,26 @@ func Handler(src Source, w http.ResponseWriter, r *http.Request, opts Options) { return } - idToken := "" + var idToken string - sessionData, err := src.GetSessions().Get(r) - if err == nil && sessionData != nil { - idToken = sessionData.IDToken + sessions := src.GetSessions() - err = src.GetSessions().DestroyForID(r, sessionData.ExternalSessionID) - if err != nil && !errors.Is(err, session.ErrKeyNotFound) { - src.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) - return + key, err := sessions.GetKey(r) + if err == nil { + sessionData, err := sessions.GetForKey(r, key) + if err == nil && sessionData != nil { + idToken = sessionData.IDToken + + err = sessions.Destroy(r, key) + if err != nil && !errors.Is(err, session.ErrKeyNotFound) { + src.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) + return + } + + logger.WithField("jti", sessionData.IDTokenJwtID). + Info("logout: successful local logout") + metrics.ObserveLogout(metrics.LogoutOperationLocal) } - - fields := log.Fields{ - "jti": sessionData.IDTokenJwtID, - } - logger.WithFields(fields).Info("logout: successful local logout") - metrics.ObserveLogout(metrics.LogoutOperationLocal) } cookie.Clear(w, cookie.Session, src.GetCookieOptsPathAware(r)) diff --git a/pkg/handler/api/logoutfrontchannel/logoutfrontchannel.go b/pkg/handler/api/logoutfrontchannel/logoutfrontchannel.go index 9019d73..c5859dc 100644 --- a/pkg/handler/api/logoutfrontchannel/logoutfrontchannel.go +++ b/pkg/handler/api/logoutfrontchannel/logoutfrontchannel.go @@ -1,6 +1,7 @@ package logoutfrontchannel import ( + "fmt" "net/http" "github.com/nais/wonderwall/pkg/cookie" @@ -29,22 +30,23 @@ func Handler(src Source, w http.ResponseWriter, r *http.Request) { src.GetLoginstatus().ClearCookie(w, src.GetCookieOptions()) } - logoutFrontchannel := src.GetClient().LogoutFrontchannel(r) - if logoutFrontchannel.MissingSidParameter() { - logger.Debug("front-channel logout: sid parameter not set in request; ignoring") + sessions := src.GetSessions() + client := src.GetClient() + key, err := getSessionKey(r, sessions, client) + if err != nil { + logger.Debugf("front-channel logout: getting session key: %+v; ignoring", err) w.WriteHeader(http.StatusAccepted) return } - sid := logoutFrontchannel.Sid() - sessionData, err := src.GetSessions().GetForID(r, sid) + sessionData, err := sessions.GetForKey(r, key) if err != nil { logger.Debugf("front-channel logout: could not get session (user might already be logged out): %+v", err) w.WriteHeader(http.StatusAccepted) return } - err = src.GetSessions().DestroyForID(r, sid) + err = sessions.Destroy(r, key) if err != nil { logger.Warnf("front-channel logout: destroying session: %+v", err) w.WriteHeader(http.StatusAccepted) @@ -57,3 +59,18 @@ func Handler(src Source, w http.ResponseWriter, r *http.Request) { metrics.ObserveLogout(metrics.LogoutOperationFrontChannel) w.WriteHeader(http.StatusOK) } + +func getSessionKey(r *http.Request, sessions *session.Handler, client *openidclient.Client) (string, error) { + logoutFrontchannel := client.LogoutFrontchannel(r) + + if logoutFrontchannel.MissingSidParameter() { + key, err := sessions.GetKey(r) + if err != nil { + return key, nil + } + return "", fmt.Errorf("neither sid parameter nor session key found in request: %w", err) + } + + sid := logoutFrontchannel.Sid() + return sessions.Key(sid), nil +} diff --git a/pkg/session/handler.go b/pkg/session/handler.go index 68b1aab..da9fe6c 100644 --- a/pkg/session/handler.go +++ b/pkg/session/handler.go @@ -91,13 +91,8 @@ func (h *Handler) Create(r *http.Request, tokens *openid.Tokens, sessionLifetime return key, nil } -// DestroyForID destroys a session for a given session ID. Note that a session ID is not equal to a session Key. -func (h *Handler) DestroyForID(r *http.Request, id string) error { - key := h.Key(id) - return h.destroyForKey(r, key) -} - -func (h *Handler) destroyForKey(r *http.Request, key string) error { +// Destroy destroys a session for a given session Key. +func (h *Handler) Destroy(r *http.Request, key string) error { retryable := func(ctx context.Context) error { err := h.store.Delete(r.Context(), key) if err == nil { @@ -150,12 +145,6 @@ func (h *Handler) GetAccessToken(r *http.Request) (string, error) { return sessionData.AccessToken, nil } -// GetForID returns the session data for a given session ID. -func (h *Handler) GetForID(r *http.Request, id string) (*Data, error) { - key := h.Key(id) - return h.GetForKey(r, key) -} - // GetForKey returns the session data for a given session Key. func (h *Handler) GetForKey(r *http.Request, key string) (*Data, error) { var encryptedSessionData *EncryptedData