From a23595b9b422a20aab1df01f0425d3aec36dd3fe Mon Sep 17 00:00:00 2001 From: ybelMekk Date: Sun, 23 Jan 2022 23:14:25 +0100 Subject: [PATCH] add: handle trigger of logout for third-party and `session_state` --- pkg/router/handler_frontchannellogout.go | 25 +++++++++++++++++------- pkg/router/router_test.go | 4 ++-- pkg/router/session.go | 10 +++++----- 3 files changed, 25 insertions(+), 14 deletions(-) diff --git a/pkg/router/handler_frontchannellogout.go b/pkg/router/handler_frontchannellogout.go index 1246248..e805e5f 100644 --- a/pkg/router/handler_frontchannellogout.go +++ b/pkg/router/handler_frontchannellogout.go @@ -1,27 +1,26 @@ package router import ( - "net/http" - log "github.com/sirupsen/logrus" + "net/http" ) // FrontChannelLogout triggers logout triggered by a third-party. func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { - params := r.URL.Query() - sid := params.Get("sid") + sessionParamKeys := []string{"sid", "session_state"} + externalSessionID := extractExternalSessionID(r, sessionParamKeys) // Unconditionally destroy all local references to the session. h.deleteCookie(w, SessionCookieName, h.Cookies) - if len(sid) == 0 { - log.Info("sid parameter not set in request; ignoring") + if len(externalSessionID) == 0 { + log.Infof("any of parameters %q not set in request; ignoring", sessionParamKeys) h.DeleteSessionFallback(w, r) w.WriteHeader(http.StatusOK) return } - sessionID := h.localSessionID(sid) + sessionID := h.localSessionID(externalSessionID) err := h.destroySession(w, r, sessionID) if err != nil { @@ -31,3 +30,15 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } + +func extractExternalSessionID(r *http.Request, paramKeys []string) string { + params := r.URL.Query() + var sessionId = "" + for _, k := range paramKeys { + sessionId = params.Get(k) + if len(sessionId) == 0 { + continue + } + } + return sessionId +} diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 24db283..ef3a4d2 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -321,7 +321,7 @@ func TestHandler_FrontChannelLogoutWithCheckSessionIframe(t *testing.T) { ciphertext, err := base64.StdEncoding.DecodeString(sessionCookie.Value) assert.NoError(t, err) - sid, err := h.Crypter.Decrypt(ciphertext) + sessionState, err := h.Crypter.Decrypt(ciphertext) assert.NoError(t, err) frontchannelLogoutURL, err := url.Parse(server.URL) @@ -330,7 +330,7 @@ func TestHandler_FrontChannelLogoutWithCheckSessionIframe(t *testing.T) { frontchannelLogoutURL.Path = "/oauth2/logout/frontchannel" values := url.Values{} - values.Add("sid", string(sid)) + values.Add("session_state", string(sessionState)) values.Add("iss", idp.GetOpenIDConfiguration().Issuer) frontchannelLogoutURL.RawQuery = values.Encode() diff --git a/pkg/router/session.go b/pkg/router/session.go index 71d5e20..14f2353 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -15,13 +15,13 @@ import ( "github.com/nais/wonderwall/pkg/session" ) -// localSessionID prefixes the given `sid` with the given client ID to prevent key collisions. -// `sid` is a key that refers to the user's unique SSO session at the Identity Provider, and the same key is present +// localSessionID prefixes the given `sid` or `session_state` with the given client ID to prevent key collisions. +// `sid` or `session_state` is a key that refers to the user's unique SSO session at the Identity Provider, and the same key is present // in all tokens acquired by any Relying Party (such as Wonderwall) during that session. -// Thus, we cannot assume that the value of `sid` to uniquely identify the pair of (user, application session) +// Thus, we cannot assume that the value of `sid` or `session_state` to uniquely identify the pair of (user, application session) // if using a shared session store. -func (h *Handler) localSessionID(sid string) string { - return fmt.Sprintf("%s:%s:%s", h.Config.OpenID.Provider, h.Provider.GetClientConfiguration().GetClientID(), sid) +func (h *Handler) localSessionID(externalSessionID string) string { + return fmt.Sprintf("%s:%s:%s", h.Config.OpenID.Provider, h.Provider.GetClientConfiguration().GetClientID(), externalSessionID) } func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {