refactor: robustify logout routes

Co-authored-by: Morten Lied Johansen <morten.lied.johansen@nav.no>
Co-authored-by: Sindre Rødseth Hansen <sindre.rodseth.hansen@nav.no>
This commit is contained in:
Trong Huu Nguyen
2021-10-01 09:35:21 +02:00
parent cc8ba980ca
commit 03eec9d2b8
3 changed files with 51 additions and 50 deletions

View File

@@ -326,20 +326,6 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
// Logout triggers self-initiated for the current user
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
var idToken string
sess, err := h.getSessionFromCookie(r)
if err == nil && sess != nil && sess.OAuth2Token != nil {
idToken = sess.IDTokenSerialized
err = h.Sessions.Delete(r.Context(), h.localSessionID(sess.ExternalSessionID))
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
h.deleteCookie(w, h.GetSessionCookieName())
}
u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint)
if err != nil {
log.Error(err)
@@ -347,6 +333,21 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
return
}
var idToken string
sess, err := h.getSessionFromCookie(r)
if err == nil && sess != nil && sess.OAuth2Token != nil {
idToken = sess.IDTokenSerialized
err = h.destroySession(r, h.localSessionID(sess.ExternalSessionID))
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
h.deleteCookie(w, h.GetSessionCookieName())
v := u.Query()
v.Add("post_logout_redirect_uri", PostLogoutRedirectURI(r, h.Config.PostLogoutRedirectURI))
@@ -362,42 +363,24 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
params := r.URL.Query()
iss := params.Get("iss")
sid := params.Get("sid")
if len(sid) == 0 || len(iss) == 0 {
if len(sid) == 0 {
log.Error("sid not set for front-channel logout")
w.WriteHeader(http.StatusBadRequest)
return
}
sessionID := h.localSessionID(sid)
// From here on, check that 'iss' from request matches data found in access token.
accessToken, err := h.getAccessTokenFromSession(r, sessionID)
if accessToken == nil {
// Can't remove session because it doesn't exist. Maybe it was garbage collected.
// We regard this as a redundant logout and return 200 OK.
return
}
err := h.destroySession(r, sessionID)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
// Session is already destroyed at the OP and is highly unlikely to be used again.
}
err = jwt.Validate(accessToken, jwt.WithClaimValue("iss", iss))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
// All verified; delete session.
err = h.Sessions.Delete(r.Context(), sessionID)
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
// Unconditionally destroy all local references to the session.
h.deleteCookie(w, h.GetSessionCookieName())
}
func New(handler *Handler, prefixes []string) chi.Router {

View File

@@ -1,7 +1,9 @@
package router
import (
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/lestrrat-go/jwx/jwt"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
@@ -27,6 +29,10 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID)
if err != nil {
if errors.Is(err, redis.Nil) {
// TODO: attempt to fetch encrypted data from fallback session cookie (if set)
}
return nil, fmt.Errorf("reading session from store: %w", err)
}
@@ -81,6 +87,8 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
if err != nil {
// TODO: fallback to writing encrypted session data to cookie
return fmt.Errorf("writing session to store: %w", err)
}
@@ -90,8 +98,13 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) (jwt.Token, error) {
encryptedSession, err := h.Sessions.Read(r.Context(), sessionID)
if err != nil {
// Session not found; ignoring
return nil, nil
if errors.Is(err, redis.Nil) {
// Session not found; ignoring
return nil, nil
}
// TODO: fetch from fallback session cookie (if set)
return nil, fmt.Errorf("fetching session from store: %w", err)
}
sessionData, err := encryptedSession.Decrypt(h.Crypter)
@@ -107,3 +120,13 @@ func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) (
return accessToken, nil
}
func (h *Handler) destroySession(r *http.Request, sessionID string) error {
err := h.Sessions.Delete(r.Context(), sessionID)
if err != nil {
return fmt.Errorf("deleting session from store: %w", err)
}
// TODO: delete fallback session cookie (if set)
return nil
}

View File

@@ -8,24 +8,21 @@ import (
)
type redisSessionStore struct {
client redis.Cmdable
client redis.Cmdable
}
var _ Store = &redisSessionStore{}
func NewRedis(client redis.Cmdable) Store {
return &redisSessionStore{
client: client,
client: client,
}
}
func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedData, error) {
encryptedData := &EncryptedData{}
err := metrics.ObserveRedisLatency("Read", func() error {
var err error
status := s.client.Get(ctx, key)
err = status.Scan(encryptedData)
return err
return s.client.Get(ctx, key).Scan(encryptedData)
})
if err != nil {
return nil, err
@@ -36,14 +33,12 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedDat
func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error {
return metrics.ObserveRedisLatency("Write", func() error {
status := s.client.Set(ctx, key, value, expiration)
return status.Err()
return s.client.Set(ctx, key, value, expiration).Err()
})
}
func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error {
return metrics.ObserveRedisLatency("Delete", func() error {
status := s.client.Del(ctx, keys...)
return status.Err()
return s.client.Del(ctx, keys...).Err()
})
}