diff --git a/pkg/auth/validate.go b/pkg/auth/validate.go index 0b8897b..eddcedf 100644 --- a/pkg/auth/validate.go +++ b/pkg/auth/validate.go @@ -8,20 +8,20 @@ import ( "golang.org/x/oauth2" ) -func ValidateIdToken(ctx context.Context, verifier *oidc.IDTokenVerifier, token *oauth2.Token, nonce string) error { +func ValidateIdToken(ctx context.Context, verifier *oidc.IDTokenVerifier, token *oauth2.Token, nonce string) (*oidc.IDToken, error) { raw, ok := token.Extra("id_token").(string) if !ok { - return fmt.Errorf("missing id_token in token response") + return nil, fmt.Errorf("missing id_token in token response") } idToken, err := verifier.Verify(ctx, raw) if err != nil { - return err + return nil, err } if idToken.Nonce != nonce { - return fmt.Errorf("nonce does not match") + return nil, fmt.Errorf("nonce does not match") } - return nil + return idToken, nil } diff --git a/pkg/router/router.go b/pkg/router/router.go index 2665e4c..3b89a7d 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -28,6 +28,7 @@ import ( const ( SessionMaxLifetime = time.Hour + LoginCookieLifetime = 10 * time.Minute ScopeOpenID = "openid" SessionCookieName = "io.nais.wonderwall.session" StateCookieName = "io.nais.wonderwall.state" @@ -123,28 +124,19 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { return } - http.SetCookie(w, &http.Cookie{ - Name: SessionCookieName, - Value: params.session, - Path: "/", - Expires: time.Now().Add(SessionMaxLifetime), - Secure: true, - SameSite: http.SameSiteLaxMode, - }) - - err = h.setEncryptedCookie(w, StateCookieName, params.state) + err = h.setEncryptedCookie(w, StateCookieName, params.state, LoginCookieLifetime) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } - err = h.setEncryptedCookie(w, NonceCookieName, params.nonce) + err = h.setEncryptedCookie(w, NonceCookieName, params.nonce, LoginCookieLifetime) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } - err = h.setEncryptedCookie(w, CodeVerifierCookieName, params.codeVerifier) + err = h.setEncryptedCookie(w, CodeVerifierCookieName, params.codeVerifier, LoginCookieLifetime) if err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -192,7 +184,7 @@ func (h *Handler) SignedJWTProfileAssertion(expiration time.Duration) (string, e return result.CompactSerialize() } -func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintext string) error { +func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintext string, expiresIn time.Duration) error { ciphertext, err := h.Crypter.Encrypt([]byte(plaintext)) if err != nil { return fmt.Errorf("unable to encrypt cookie '%s': %w", key, err) @@ -201,7 +193,7 @@ func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintex http.SetCookie(w, &http.Cookie{ Name: key, Value: base64.StdEncoding.EncodeToString(ciphertext), - Expires: time.Now().Add(10 * time.Minute), + Expires: time.Now().Add(expiresIn), Secure: true, SameSite: http.SameSiteLaxMode, }) @@ -229,13 +221,6 @@ func (h *Handler) getEncryptedCookie(r *http.Request, key string) (string, error } func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { - sessionCookie, err := r.Cookie(SessionCookieName) - if err != nil { - log.Error(err) - w.WriteHeader(http.StatusUnauthorized) - return - } - state, err := h.getEncryptedCookie(r, StateCookieName) if err != nil { log.Error(err) @@ -290,14 +275,29 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - err = auth.ValidateIdToken(r.Context(), h.IdTokenVerifier, token, nonce) + idToken, err := auth.ValidateIdToken(r.Context(), h.IdTokenVerifier, token, nonce) if err != nil { log.Error(err) w.WriteHeader(http.StatusUnauthorized) return } + var claims struct { + SessionID string `json:"sid"` + } + if err := idToken.Claims(&claims); err != nil { + log.Error(err) + w.WriteHeader(http.StatusUnauthorized) + return + } - h.sessions[sessionCookie.Value] = token + err = h.setEncryptedCookie(w, SessionCookieName, claims.SessionID, SessionMaxLifetime) + if err != nil { + log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + h.sessions[claims.SessionID] = token http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } @@ -309,15 +309,15 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { upstreamRequest := r.Clone(ctx) // Get credentials from session cache - sessionCookie, err := r.Cookie(SessionCookieName) + sessionID, err := h.getEncryptedCookie(r, SessionCookieName) if err != nil { log.Tracef("no session cookie; should redirect to /oauth2/login") http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) return } - token, ok := h.sessions[sessionCookie.Value] + token, ok := h.sessions[sessionID] if !ok { - log.Tracef("no token stored for session %s; needs garbage collection client side", sessionCookie.Value) + log.Tracef("no token stored for session %s; needs garbage collection client side", sessionID) http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) return } @@ -363,27 +363,25 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { // Logout triggers self-initiated for the current user func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { - sessionCookie, err := r.Cookie(SessionCookieName) + sessionID, err := h.getEncryptedCookie(r, SessionCookieName) if err != nil { log.Tracef("no session cookie; should redirect to /oauth2/login") http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) return } - _, ok := h.sessions[sessionCookie.Value] + _, ok := h.sessions[sessionID] if !ok { - log.Tracef("no token stored for session %s; needs garbage collection client side", sessionCookie.Value) + log.Tracef("no token stored for session %s; needs garbage collection client side", sessionID) http.Redirect(w, r, "/oauth2/login", http.StatusTemporaryRedirect) return } - delete(h.sessions, sessionCookie.Value) + delete(h.sessions, sessionID) http.SetCookie(w, &http.Cookie{ Name: SessionCookieName, - Value: "", Path: "/", - Expires: time.Unix(0,0), Secure: true, SameSite: http.SameSiteLaxMode, })