mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-10 18:37:00 +00:00
refactor: robustify logout routes
Co-authored-by: Morten Lied Johansen <morten.lied.johansen@nav.no> Co-authored-by: Sindre Rødseth Hansen <sindre.rodseth.hansen@nav.no>
This commit is contained in:
@@ -326,20 +326,6 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// Logout triggers self-initiated for the current user
|
||||
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
var idToken string
|
||||
|
||||
sess, err := h.getSessionFromCookie(r)
|
||||
if err == nil && sess != nil && sess.OAuth2Token != nil {
|
||||
idToken = sess.IDTokenSerialized
|
||||
err = h.Sessions.Delete(r.Context(), h.localSessionID(sess.ExternalSessionID))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
h.deleteCookie(w, h.GetSessionCookieName())
|
||||
}
|
||||
|
||||
u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
@@ -347,6 +333,21 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var idToken string
|
||||
|
||||
sess, err := h.getSessionFromCookie(r)
|
||||
if err == nil && sess != nil && sess.OAuth2Token != nil {
|
||||
idToken = sess.IDTokenSerialized
|
||||
err = h.destroySession(r, h.localSessionID(sess.ExternalSessionID))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.deleteCookie(w, h.GetSessionCookieName())
|
||||
|
||||
v := u.Query()
|
||||
v.Add("post_logout_redirect_uri", PostLogoutRedirectURI(r, h.Config.PostLogoutRedirectURI))
|
||||
|
||||
@@ -362,42 +363,24 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
params := r.URL.Query()
|
||||
|
||||
iss := params.Get("iss")
|
||||
sid := params.Get("sid")
|
||||
|
||||
if len(sid) == 0 || len(iss) == 0 {
|
||||
if len(sid) == 0 {
|
||||
log.Error("sid not set for front-channel logout")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := h.localSessionID(sid)
|
||||
|
||||
// From here on, check that 'iss' from request matches data found in access token.
|
||||
accessToken, err := h.getAccessTokenFromSession(r, sessionID)
|
||||
if accessToken == nil {
|
||||
// Can't remove session because it doesn't exist. Maybe it was garbage collected.
|
||||
// We regard this as a redundant logout and return 200 OK.
|
||||
return
|
||||
}
|
||||
err := h.destroySession(r, sessionID)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
// Session is already destroyed at the OP and is highly unlikely to be used again.
|
||||
}
|
||||
|
||||
err = jwt.Validate(accessToken, jwt.WithClaimValue("iss", iss))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// All verified; delete session.
|
||||
err = h.Sessions.Delete(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Unconditionally destroy all local references to the session.
|
||||
h.deleteCookie(w, h.GetSessionCookieName())
|
||||
}
|
||||
|
||||
func New(handler *Handler, prefixes []string) chi.Router {
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
@@ -27,6 +29,10 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
|
||||
|
||||
encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
// TODO: attempt to fetch encrypted data from fallback session cookie (if set)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("reading session from store: %w", err)
|
||||
}
|
||||
|
||||
@@ -81,6 +87,8 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
|
||||
|
||||
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
|
||||
if err != nil {
|
||||
// TODO: fallback to writing encrypted session data to cookie
|
||||
|
||||
return fmt.Errorf("writing session to store: %w", err)
|
||||
}
|
||||
|
||||
@@ -90,8 +98,13 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
|
||||
func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) (jwt.Token, error) {
|
||||
encryptedSession, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
// Session not found; ignoring
|
||||
return nil, nil
|
||||
if errors.Is(err, redis.Nil) {
|
||||
// Session not found; ignoring
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TODO: fetch from fallback session cookie (if set)
|
||||
return nil, fmt.Errorf("fetching session from store: %w", err)
|
||||
}
|
||||
|
||||
sessionData, err := encryptedSession.Decrypt(h.Crypter)
|
||||
@@ -107,3 +120,13 @@ func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) (
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (h *Handler) destroySession(r *http.Request, sessionID string) error {
|
||||
err := h.Sessions.Delete(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting session from store: %w", err)
|
||||
}
|
||||
|
||||
// TODO: delete fallback session cookie (if set)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,24 +8,21 @@ import (
|
||||
)
|
||||
|
||||
type redisSessionStore struct {
|
||||
client redis.Cmdable
|
||||
client redis.Cmdable
|
||||
}
|
||||
|
||||
var _ Store = &redisSessionStore{}
|
||||
|
||||
func NewRedis(client redis.Cmdable) Store {
|
||||
return &redisSessionStore{
|
||||
client: client,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedData, error) {
|
||||
encryptedData := &EncryptedData{}
|
||||
err := metrics.ObserveRedisLatency("Read", func() error {
|
||||
var err error
|
||||
status := s.client.Get(ctx, key)
|
||||
err = status.Scan(encryptedData)
|
||||
return err
|
||||
return s.client.Get(ctx, key).Scan(encryptedData)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -36,14 +33,12 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedDat
|
||||
|
||||
func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error {
|
||||
return metrics.ObserveRedisLatency("Write", func() error {
|
||||
status := s.client.Set(ctx, key, value, expiration)
|
||||
return status.Err()
|
||||
return s.client.Set(ctx, key, value, expiration).Err()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error {
|
||||
return metrics.ObserveRedisLatency("Delete", func() error {
|
||||
status := s.client.Del(ctx, keys...)
|
||||
return status.Err()
|
||||
return s.client.Del(ctx, keys...).Err()
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user