From b2e89f32fa3a42f73917c05510281e4e8457f82d Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 30 Sep 2021 10:05:45 +0200 Subject: [PATCH] refactor: ensure cookies are properly disposed of MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Sindre Rødseth Hansen --- pkg/router/cookies.go | 17 ++++++--- pkg/router/router_test.go | 77 ++++++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 40 deletions(-) diff --git a/pkg/router/cookies.go b/pkg/router/cookies.go index 87d890b..5700117 100644 --- a/pkg/router/cookies.go +++ b/pkg/router/cookies.go @@ -75,13 +75,14 @@ func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintex } http.SetCookie(w, &http.Cookie{ - Name: key, - Value: base64.StdEncoding.EncodeToString(ciphertext), Expires: time.Now().Add(expiresIn), - Secure: h.SecureCookies, HttpOnly: true, - SameSite: http.SameSiteLaxMode, + MaxAge: int(expiresIn.Seconds()), + Name: key, Path: "/", + SameSite: http.SameSiteLaxMode, + Secure: h.SecureCookies, + Value: base64.StdEncoding.EncodeToString(ciphertext), }) return nil @@ -107,10 +108,14 @@ func (h *Handler) getEncryptedCookie(r *http.Request, key string) (string, error } func (h *Handler) deleteCookie(w http.ResponseWriter, key string) { + expires := time.Now().Add(-7 * 24 * time.Hour) http.SetCookie(w, &http.Cookie{ + Expires: expires, + HttpOnly: true, + MaxAge: -1, Name: key, - Secure: h.SecureCookies, - SameSite: http.SameSiteLaxMode, Path: "/", + SameSite: http.SameSiteLaxMode, + Secure: h.SecureCookies, }) } diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 6ac9b14..97b5902 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -153,8 +153,12 @@ func TestHandler_Login(t *testing.T) { prefixes := config.ParseIngresses([]string{""}) r := router.New(h, prefixes) + jar, err := cookiejar.New(nil) + assert.NoError(t, err) + server := httptest.NewServer(r) client := server.Client() + client.Jar = jar client.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } @@ -164,10 +168,17 @@ func TestHandler_Login(t *testing.T) { h.Config.WellKnown.AuthorizationEndpoint = idpserver.URL + "/authorize" - req, err := client.Get(server.URL + "/oauth2/login") + loginURL, err := url.Parse(server.URL + "/oauth2/login") + assert.NoError(t, err) + + req, err := client.Get(loginURL.String()) assert.NoError(t, err) defer req.Body.Close() + cookies := client.Jar.Cookies(loginURL) + loginCookie := getCookieFromJar(h.GetLoginCookieName(), cookies) + assert.NotNil(t, loginCookie) + location := req.Header.Get("location") u, err := url.Parse(location) assert.NoError(t, err) @@ -222,10 +233,18 @@ func TestHandler_Callback_and_Logout(t *testing.T) { } // First, run /oauth2/login to set cookies - req, err := client.Get(server.URL + "/oauth2/login") + loginURL, err := url.Parse(server.URL + "/oauth2/login") + req, err := client.Get(loginURL.String()) assert.NoError(t, err) defer req.Body.Close() + cookies := client.Jar.Cookies(loginURL) + sessionCookie := getCookieFromJar(h.GetSessionCookieName(), cookies) + loginCookie := getCookieFromJar(h.GetLoginCookieName(), cookies) + + assert.Nil(t, sessionCookie) + assert.NotNil(t, loginCookie) + // Get authorization URL location := req.Header.Get("location") u, err := url.Parse(location) @@ -245,40 +264,25 @@ func TestHandler_Callback_and_Logout(t *testing.T) { req, err = client.Get(callbackURL.String()) assert.NoError(t, err) - cookies := client.Jar.Cookies(callbackURL) - var sessionCookie *http.Cookie - var loginCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == h.GetSessionCookieName() { - sessionCookie = cookie - } - - if cookie.Name == h.GetLoginCookieName() { - loginCookie = cookie - } - } + cookies = client.Jar.Cookies(callbackURL) + sessionCookie = getCookieFromJar(h.GetSessionCookieName(), cookies) + loginCookie = getCookieFromJar(h.GetLoginCookieName(), cookies) assert.NotNil(t, sessionCookie) - - assert.NotNil(t, loginCookie) - assert.Empty(t, loginCookie.Value) - assert.True(t, loginCookie.Expires.Before(time.Now())) + assert.Nil(t, loginCookie) // Request self-initiated logout - req, err = client.Get(server.URL + "/oauth2/logout") + logoutURL, err := url.Parse(server.URL + "/oauth2/logout") + assert.NoError(t, err) + + req, err = client.Get(logoutURL.String()) assert.NoError(t, err) defer req.Body.Close() - cookies = client.Jar.Cookies(callbackURL) - for _, cookie := range cookies { - if cookie.Name == h.GetSessionCookieName() { - sessionCookie = cookie - } - } + cookies = client.Jar.Cookies(logoutURL) + sessionCookie = getCookieFromJar(h.GetSessionCookieName(), cookies) - assert.NotNil(t, sessionCookie) - assert.Empty(t, sessionCookie.Value) - assert.True(t, sessionCookie.Expires.Before(time.Now())) + assert.Nil(t, sessionCookie) // Get endsession endpoint after local logout location = req.Header.Get("location") @@ -349,12 +353,7 @@ func TestHandler_FrontChannelLogout(t *testing.T) { assert.NoError(t, err) cookies := client.Jar.Cookies(callbackURL) - var sessionCookie *http.Cookie - for _, cookie := range cookies { - if cookie.Name == h.GetSessionCookieName() { - sessionCookie = cookie - } - } + sessionCookie := getCookieFromJar(h.GetSessionCookieName(), cookies) assert.NotNil(t, sessionCookie) @@ -379,3 +378,13 @@ func TestHandler_FrontChannelLogout(t *testing.T) { assert.NoError(t, err) defer req.Body.Close() } + +func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie { + for _, cookie := range cookies { + if cookie.Name == name { + return cookie + } + } + + return nil +} \ No newline at end of file