From 25221added1aaf470b274223264bd5fe7ef11954 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Wed, 29 Sep 2021 13:27:30 +0200 Subject: [PATCH] rename callbackparams to logincookie for clarity, ensure logincookie is deleted when no longer needed --- pkg/router/cookies.go | 40 +++++++++++++++++++++++++++++++-------- pkg/router/router.go | 28 +++++++-------------------- pkg/router/router_test.go | 9 +++++++++ 3 files changed, 48 insertions(+), 29 deletions(-) diff --git a/pkg/router/cookies.go b/pkg/router/cookies.go index 65f9442..87d890b 100644 --- a/pkg/router/cookies.go +++ b/pkg/router/cookies.go @@ -8,40 +8,64 @@ import ( "time" ) +const ( + LoginCookieLifetime = 2 * time.Minute + + SessionCookieNameTemplate = "io.nais.wonderwall.%s.session" + LoginCookieNameTemplate = "io.nais.wonderwall.%s.callback" +) + type Cookie struct { name string value string expiresIn time.Duration } -type CallbackParams struct { +type LoginCookie struct { State string `json:"state"` Nonce string `json:"nonce"` CodeVerifier string `json:"code_verifier"` Referer string `json:"referer"` } -func (h *Handler) getCallbackCookieName() string { - return fmt.Sprintf(CallbackCookieNameTemplate, h.Config.ClientID) +func (h *Handler) GetLoginCookieName() string { + return fmt.Sprintf(LoginCookieNameTemplate, h.Config.ClientID) } func (h *Handler) GetSessionCookieName() string { return fmt.Sprintf(SessionCookieNameTemplate, h.Config.ClientID) } -func (h *Handler) getCallbackParams(r *http.Request) (*CallbackParams, error) { - callbackCookieString, err := h.getEncryptedCookie(r, h.getCallbackCookieName()) +func (h *Handler) getLoginCookie(w http.ResponseWriter, r *http.Request) (*LoginCookie, error) { + loginCookieJson, err := h.getEncryptedCookie(r, h.GetLoginCookieName()) if err != nil { return nil, err } - var callbackParams CallbackParams - err = json.Unmarshal([]byte(callbackCookieString), &callbackParams) + var loginCookie LoginCookie + err = json.Unmarshal([]byte(loginCookieJson), &loginCookie) if err != nil { return nil, err } - return &callbackParams, nil + // delete cookie as we no longer need it + h.deleteCookie(w, h.GetLoginCookieName()) + + return &loginCookie, nil +} + +func (h *Handler) setLoginCookie(w http.ResponseWriter, loginCookie *LoginCookie) error { + loginCookieJson, err := json.Marshal(loginCookie) + if err != nil { + return fmt.Errorf("marshalling login cookie: %w", err) + } + + err = h.setEncryptedCookie(w, h.GetLoginCookieName(), string(loginCookieJson), LoginCookieLifetime) + if err != nil { + return err + } + + return nil } func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintext string, expiresIn time.Duration) error { diff --git a/pkg/router/router.go b/pkg/router/router.go index ce8f63d..87a1746 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "crypto/sha256" "encoding/base64" - "encoding/json" "errors" "fmt" "io" @@ -31,10 +30,6 @@ import ( ) const ( - LoginCookieLifetime = 2 * time.Minute - SessionCookieNameTemplate = "io.nais.wonderwall.%s.session" - CallbackCookieNameTemplate = "io.nais.wonderwall.%s.callback" - RedirectURLParameter = "redirect" SecurityLevelURLParameter = "level" LocaleURLParameter = "locale" @@ -215,21 +210,12 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { return } - callbackCookies := &CallbackParams{ + err = h.setLoginCookie(w, &LoginCookie{ State: params.state, Nonce: params.nonce, CodeVerifier: params.codeVerifier, Referer: CanonicalRedirectURL(r), - } - - jsonString, err := json.Marshal(callbackCookies) - if err != nil { - log.Error(err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - err = h.setEncryptedCookie(w, h.getCallbackCookieName(), string(jsonString), LoginCookieLifetime) + }) if err != nil { log.Error(err) w.WriteHeader(http.StatusInternalServerError) @@ -240,7 +226,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { } func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { - cookies, err := h.getCallbackParams(r) + loginCookie, err := h.getLoginCookie(w, r) if err != nil { log.Error(err) w.WriteHeader(http.StatusUnauthorized) @@ -256,7 +242,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - if params.Get("state") != cookies.State { + if params.Get("state") != loginCookie.State { log.Error("state parameter mismatch") w.WriteHeader(http.StatusUnauthorized) return @@ -270,7 +256,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { } opts := []oauth2.AuthCodeOption{ - oauth2.SetAuthURLParam("code_verifier", cookies.CodeVerifier), + oauth2.SetAuthURLParam("code_verifier", loginCookie.CodeVerifier), oauth2.SetAuthURLParam("client_assertion", assertion), oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"), } @@ -291,7 +277,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { validateOpts := []jwt.ValidateOption{ jwt.WithAudience(h.Config.ClientID), - jwt.WithClaimValue("nonce", cookies.Nonce), + jwt.WithClaimValue("nonce", loginCookie.Nonce), jwt.WithIssuer(h.Config.WellKnown.Issuer), jwt.WithAcceptableSkew(5 * time.Second), jwt.WithRequiredClaim("sid"), @@ -334,7 +320,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - http.Redirect(w, r, cookies.Referer, http.StatusTemporaryRedirect) + http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect) } // Proxy all requests upstream diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 303ebce..6ac9b14 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -247,14 +247,23 @@ func TestHandler_Callback_and_Logout(t *testing.T) { cookies := client.Jar.Cookies(callbackURL) var sessionCookie *http.Cookie + var loginCookie *http.Cookie for _, cookie := range cookies { if cookie.Name == h.GetSessionCookieName() { sessionCookie = cookie } + + if cookie.Name == h.GetLoginCookieName() { + loginCookie = cookie + } } assert.NotNil(t, sessionCookie) + assert.NotNil(t, loginCookie) + assert.Empty(t, loginCookie.Value) + assert.True(t, loginCookie.Expires.Before(time.Now())) + // Request self-initiated logout req, err = client.Get(server.URL + "/oauth2/logout") assert.NoError(t, err)