mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 07:12:48 +00:00
rename callbackparams to logincookie for clarity, ensure logincookie is deleted when no longer needed
This commit is contained in:
@@ -8,40 +8,64 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
LoginCookieLifetime = 2 * time.Minute
|
||||
|
||||
SessionCookieNameTemplate = "io.nais.wonderwall.%s.session"
|
||||
LoginCookieNameTemplate = "io.nais.wonderwall.%s.callback"
|
||||
)
|
||||
|
||||
type Cookie struct {
|
||||
name string
|
||||
value string
|
||||
expiresIn time.Duration
|
||||
}
|
||||
|
||||
type CallbackParams struct {
|
||||
type LoginCookie struct {
|
||||
State string `json:"state"`
|
||||
Nonce string `json:"nonce"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
Referer string `json:"referer"`
|
||||
}
|
||||
|
||||
func (h *Handler) getCallbackCookieName() string {
|
||||
return fmt.Sprintf(CallbackCookieNameTemplate, h.Config.ClientID)
|
||||
func (h *Handler) GetLoginCookieName() string {
|
||||
return fmt.Sprintf(LoginCookieNameTemplate, h.Config.ClientID)
|
||||
}
|
||||
|
||||
func (h *Handler) GetSessionCookieName() string {
|
||||
return fmt.Sprintf(SessionCookieNameTemplate, h.Config.ClientID)
|
||||
}
|
||||
|
||||
func (h *Handler) getCallbackParams(r *http.Request) (*CallbackParams, error) {
|
||||
callbackCookieString, err := h.getEncryptedCookie(r, h.getCallbackCookieName())
|
||||
func (h *Handler) getLoginCookie(w http.ResponseWriter, r *http.Request) (*LoginCookie, error) {
|
||||
loginCookieJson, err := h.getEncryptedCookie(r, h.GetLoginCookieName())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var callbackParams CallbackParams
|
||||
err = json.Unmarshal([]byte(callbackCookieString), &callbackParams)
|
||||
var loginCookie LoginCookie
|
||||
err = json.Unmarshal([]byte(loginCookieJson), &loginCookie)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &callbackParams, nil
|
||||
// delete cookie as we no longer need it
|
||||
h.deleteCookie(w, h.GetLoginCookieName())
|
||||
|
||||
return &loginCookie, nil
|
||||
}
|
||||
|
||||
func (h *Handler) setLoginCookie(w http.ResponseWriter, loginCookie *LoginCookie) error {
|
||||
loginCookieJson, err := json.Marshal(loginCookie)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling login cookie: %w", err)
|
||||
}
|
||||
|
||||
err = h.setEncryptedCookie(w, h.GetLoginCookieName(), string(loginCookieJson), LoginCookieLifetime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintext string, expiresIn time.Duration) error {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -31,10 +30,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
LoginCookieLifetime = 2 * time.Minute
|
||||
SessionCookieNameTemplate = "io.nais.wonderwall.%s.session"
|
||||
CallbackCookieNameTemplate = "io.nais.wonderwall.%s.callback"
|
||||
|
||||
RedirectURLParameter = "redirect"
|
||||
SecurityLevelURLParameter = "level"
|
||||
LocaleURLParameter = "locale"
|
||||
@@ -215,21 +210,12 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
callbackCookies := &CallbackParams{
|
||||
err = h.setLoginCookie(w, &LoginCookie{
|
||||
State: params.state,
|
||||
Nonce: params.nonce,
|
||||
CodeVerifier: params.codeVerifier,
|
||||
Referer: CanonicalRedirectURL(r),
|
||||
}
|
||||
|
||||
jsonString, err := json.Marshal(callbackCookies)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.setEncryptedCookie(w, h.getCallbackCookieName(), string(jsonString), LoginCookieLifetime)
|
||||
})
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
@@ -240,7 +226,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
cookies, err := h.getCallbackParams(r)
|
||||
loginCookie, err := h.getLoginCookie(w, r)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
@@ -256,7 +242,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if params.Get("state") != cookies.State {
|
||||
if params.Get("state") != loginCookie.State {
|
||||
log.Error("state parameter mismatch")
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
@@ -270,7 +256,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
opts := []oauth2.AuthCodeOption{
|
||||
oauth2.SetAuthURLParam("code_verifier", cookies.CodeVerifier),
|
||||
oauth2.SetAuthURLParam("code_verifier", loginCookie.CodeVerifier),
|
||||
oauth2.SetAuthURLParam("client_assertion", assertion),
|
||||
oauth2.SetAuthURLParam("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"),
|
||||
}
|
||||
@@ -291,7 +277,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
validateOpts := []jwt.ValidateOption{
|
||||
jwt.WithAudience(h.Config.ClientID),
|
||||
jwt.WithClaimValue("nonce", cookies.Nonce),
|
||||
jwt.WithClaimValue("nonce", loginCookie.Nonce),
|
||||
jwt.WithIssuer(h.Config.WellKnown.Issuer),
|
||||
jwt.WithAcceptableSkew(5 * time.Second),
|
||||
jwt.WithRequiredClaim("sid"),
|
||||
@@ -334,7 +320,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, cookies.Referer, http.StatusTemporaryRedirect)
|
||||
http.Redirect(w, r, loginCookie.Referer, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// Proxy all requests upstream
|
||||
|
||||
@@ -247,14 +247,23 @@ func TestHandler_Callback_and_Logout(t *testing.T) {
|
||||
|
||||
cookies := client.Jar.Cookies(callbackURL)
|
||||
var sessionCookie *http.Cookie
|
||||
var loginCookie *http.Cookie
|
||||
for _, cookie := range cookies {
|
||||
if cookie.Name == h.GetSessionCookieName() {
|
||||
sessionCookie = cookie
|
||||
}
|
||||
|
||||
if cookie.Name == h.GetLoginCookieName() {
|
||||
loginCookie = cookie
|
||||
}
|
||||
}
|
||||
|
||||
assert.NotNil(t, sessionCookie)
|
||||
|
||||
assert.NotNil(t, loginCookie)
|
||||
assert.Empty(t, loginCookie.Value)
|
||||
assert.True(t, loginCookie.Expires.Before(time.Now()))
|
||||
|
||||
// Request self-initiated logout
|
||||
req, err = client.Get(server.URL + "/oauth2/logout")
|
||||
assert.NoError(t, err)
|
||||
|
||||
Reference in New Issue
Block a user