rename callbackparams to logincookie for clarity, ensure logincookie is deleted when no longer needed

This commit is contained in:
Trong Huu Nguyen
2021-09-29 13:27:30 +02:00
parent b60db493ac
commit 25221added
3 changed files with 48 additions and 29 deletions

View File

@@ -8,40 +8,64 @@ import (
"time"
)
const (
LoginCookieLifetime = 2 * time.Minute
SessionCookieNameTemplate = "io.nais.wonderwall.%s.session"
LoginCookieNameTemplate = "io.nais.wonderwall.%s.callback"
)
type Cookie struct {
name string
value string
expiresIn time.Duration
}
type CallbackParams struct {
type LoginCookie struct {
State string `json:"state"`
Nonce string `json:"nonce"`
CodeVerifier string `json:"code_verifier"`
Referer string `json:"referer"`
}
func (h *Handler) getCallbackCookieName() string {
return fmt.Sprintf(CallbackCookieNameTemplate, h.Config.ClientID)
func (h *Handler) GetLoginCookieName() string {
return fmt.Sprintf(LoginCookieNameTemplate, h.Config.ClientID)
}
func (h *Handler) GetSessionCookieName() string {
return fmt.Sprintf(SessionCookieNameTemplate, h.Config.ClientID)
}
func (h *Handler) getCallbackParams(r *http.Request) (*CallbackParams, error) {
callbackCookieString, err := h.getEncryptedCookie(r, h.getCallbackCookieName())
func (h *Handler) getLoginCookie(w http.ResponseWriter, r *http.Request) (*LoginCookie, error) {
loginCookieJson, err := h.getEncryptedCookie(r, h.GetLoginCookieName())
if err != nil {
return nil, err
}
var callbackParams CallbackParams
err = json.Unmarshal([]byte(callbackCookieString), &callbackParams)
var loginCookie LoginCookie
err = json.Unmarshal([]byte(loginCookieJson), &loginCookie)
if err != nil {
return nil, err
}
return &callbackParams, nil
// delete cookie as we no longer need it
h.deleteCookie(w, h.GetLoginCookieName())
return &loginCookie, nil
}
func (h *Handler) setLoginCookie(w http.ResponseWriter, loginCookie *LoginCookie) error {
loginCookieJson, err := json.Marshal(loginCookie)
if err != nil {
return fmt.Errorf("marshalling login cookie: %w", err)
}
err = h.setEncryptedCookie(w, h.GetLoginCookieName(), string(loginCookieJson), LoginCookieLifetime)
if err != nil {
return err
}
return nil
}
func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintext string, expiresIn time.Duration) error {

View File

@@ -5,7 +5,6 @@ import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
@@ -31,10 +30,6 @@ import (
)
const (
LoginCookieLifetime = 2 * time.Minute
SessionCookieNameTemplate = "io.nais.wonderwall.%s.session"
CallbackCookieNameTemplate = "io.nais.wonderwall.%s.callback"
RedirectURLParameter = "redirect"
SecurityLevelURLParameter = "level"
LocaleURLParameter = "locale"
@@ -215,21 +210,12 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
return
}
callbackCookies := &CallbackParams{
err = h.setLoginCookie(w, &LoginCookie{
State: params.state,
Nonce: params.nonce,
CodeVerifier: params.codeVerifier,
Referer: CanonicalRedirectURL(r),
}
jsonString, err := json.Marshal(callbackCookies)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
err = h.setEncryptedCookie(w, h.getCallbackCookieName(), string(jsonString), LoginCookieLifetime)
})
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
@@ -240,7 +226,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
cookies, err := h.getCallbackParams(r)
loginCookie, err := h.getLoginCookie(w, r)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusUnauthorized)
@@ -256,7 +242,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
if params.Get("state") != cookies.State {
if params.Get("state") != loginCookie.State {
log.Error("state parameter mismatch")
w.WriteHeader(http.StatusUnauthorized)
return
@@ -270,7 +256,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
}
opts := []oauth2.AuthCodeOption{
oauth2.SetAuthURLParam("code_verifier", cookies.CodeVerifier),
oauth2.SetAuthURLParam("code_verifier", loginCookie.CodeVerifier),
oauth2.SetAuthURLParam("client_assertion", assertion),
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
}
@@ -291,7 +277,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
validateOpts := []jwt.ValidateOption{
jwt.WithAudience(h.Config.ClientID),
jwt.WithClaimValue("nonce", cookies.Nonce),
jwt.WithClaimValue("nonce", loginCookie.Nonce),
jwt.WithIssuer(h.Config.WellKnown.Issuer),
jwt.WithAcceptableSkew(5 * time.Second),
jwt.WithRequiredClaim("sid"),
@@ -334,7 +320,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
http.Redirect(w, r, cookies.Referer, http.StatusTemporaryRedirect)
http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect)
}
// Proxy all requests upstream

View File

@@ -247,14 +247,23 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
cookies := client.Jar.Cookies(callbackURL)
var sessionCookie *http.Cookie
var loginCookie *http.Cookie
for _, cookie := range cookies {
if cookie.Name == h.GetSessionCookieName() {
sessionCookie = cookie
}
if cookie.Name == h.GetLoginCookieName() {
loginCookie = cookie
}
}
assert.NotNil(t, sessionCookie)
assert.NotNil(t, loginCookie)
assert.Empty(t, loginCookie.Value)
assert.True(t, loginCookie.Expires.Before(time.Now()))
// Request self-initiated logout
req, err = client.Get(server.URL + "/oauth2/logout")
assert.NoError(t, err)