From b60db493ace515b93a6cfd7ab774d1e7a2baf4a2 Mon Sep 17 00:00:00 2001 From: Morten Lied Johansen Date: Wed, 29 Sep 2021 10:20:11 +0200 Subject: [PATCH] Add ClientID to cookie names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Sindre Rødseth Hansen --- pkg/router/cookies.go | 10 +++++++++- pkg/router/router.go | 13 ++++++------- pkg/router/router_test.go | 6 +++--- pkg/router/session.go | 2 +- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/pkg/router/cookies.go b/pkg/router/cookies.go index 9d3b583..65f9442 100644 --- a/pkg/router/cookies.go +++ b/pkg/router/cookies.go @@ -21,8 +21,16 @@ type CallbackParams struct { Referer string `json:"referer"` } +func (h *Handler) getCallbackCookieName() string { + return fmt.Sprintf(CallbackCookieNameTemplate, h.Config.ClientID) +} + +func (h *Handler) GetSessionCookieName() string { + return fmt.Sprintf(SessionCookieNameTemplate, h.Config.ClientID) +} + func (h *Handler) getCallbackParams(r *http.Request) (*CallbackParams, error) { - callbackCookieString, err := h.getEncryptedCookie(r, CallbackCookieName) + callbackCookieString, err := h.getEncryptedCookie(r, h.getCallbackCookieName()) if err != nil { return nil, err } diff --git a/pkg/router/router.go b/pkg/router/router.go index 265d506..ce8f63d 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -31,10 +31,9 @@ import ( ) const ( - SessionCookieName = "io.nais.wonderwall.session" - - LoginCookieLifetime = 2 * time.Minute - CallbackCookieName = "io.nais.wonderwall.callback" + LoginCookieLifetime = 2 * time.Minute + SessionCookieNameTemplate = "io.nais.wonderwall.%s.session" + CallbackCookieNameTemplate = "io.nais.wonderwall.%s.callback" RedirectURLParameter = "redirect" SecurityLevelURLParameter = "level" @@ -230,7 +229,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { return } - err = h.setEncryptedCookie(w, CallbackCookieName, string(jsonString), LoginCookieLifetime) + err = h.setEncryptedCookie(w, h.getCallbackCookieName(), string(jsonString), LoginCookieLifetime) if err != nil { log.Error(err) w.WriteHeader(http.StatusInternalServerError) @@ -317,7 +316,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { } sessionID := h.localSessionID(externalSessionID) - err = h.setEncryptedCookie(w, SessionCookieName, sessionID, h.Config.SessionMaxLifetime) + err = h.setEncryptedCookie(w, h.GetSessionCookieName(), sessionID, h.Config.SessionMaxLifetime) if err != nil { log.Error(err) w.WriteHeader(http.StatusInternalServerError) @@ -405,7 +404,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) return } - h.deleteCookie(w, SessionCookieName) + h.deleteCookie(w, h.GetSessionCookieName()) } u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint) diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 2f5f930..303ebce 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -248,7 +248,7 @@ func TestHandler_Callback_and_Logout(t *testing.T) { cookies := client.Jar.Cookies(callbackURL) var sessionCookie *http.Cookie for _, cookie := range cookies { - if cookie.Name == router.SessionCookieName { + if cookie.Name == h.GetSessionCookieName() { sessionCookie = cookie } } @@ -262,7 +262,7 @@ func TestHandler_Callback_and_Logout(t *testing.T) { cookies = client.Jar.Cookies(callbackURL) for _, cookie := range cookies { - if cookie.Name == router.SessionCookieName { + if cookie.Name == h.GetSessionCookieName() { sessionCookie = cookie } } @@ -342,7 +342,7 @@ func TestHandler_FrontChannelLogout(t *testing.T) { cookies := client.Jar.Cookies(callbackURL) var sessionCookie *http.Cookie for _, cookie := range cookies { - if cookie.Name == router.SessionCookieName { + if cookie.Name == h.GetSessionCookieName() { sessionCookie = cookie } } diff --git a/pkg/router/session.go b/pkg/router/session.go index adbc881..ca2ed1e 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -7,7 +7,7 @@ import ( ) func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) { - sessionID, err := h.getEncryptedCookie(r, SessionCookieName) + sessionID, err := h.getEncryptedCookie(r, h.GetSessionCookieName()) if err != nil { return nil, fmt.Errorf("no session cookie: %w", err) }