mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-13 03:47:02 +00:00
126 lines
3.8 KiB
Go
126 lines
3.8 KiB
Go
package router
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/nais/wonderwall/pkg/jwt"
|
|
"github.com/nais/wonderwall/pkg/session"
|
|
)
|
|
|
|
// localSessionID prefixes the given `sid` or `session_state` with the given client ID to prevent key collisions.
|
|
// `sid` or `session_state` is a key that refers to the user's unique SSO session at the Identity Provider, and the same key is present
|
|
// in all tokens acquired by any Relying Party (such as Wonderwall) during that session.
|
|
// Thus, we cannot assume that the value of `sid` or `session_state` to uniquely identify the pair of (user, application session)
|
|
// if using a shared session store.
|
|
func (h *Handler) localSessionID(sessionID string) string {
|
|
return fmt.Sprintf("%s:%s:%s", h.Config.OpenID.Provider, h.Provider.GetClientConfiguration().GetClientID(), sessionID)
|
|
}
|
|
|
|
func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {
|
|
sessionID, err := h.getDecryptedCookie(r, SessionCookieName)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("no session cookie: %w", err)
|
|
}
|
|
|
|
sessionData, err := h.getSession(r.Context(), sessionID)
|
|
if err == nil {
|
|
h.DeleteSessionFallback(w, r)
|
|
return sessionData, nil
|
|
}
|
|
|
|
if errors.Is(err, redis.Nil) {
|
|
return nil, fmt.Errorf("session not found in store: %w", err)
|
|
}
|
|
|
|
log.Warnf("get session: store is unavailable: %+v; using cookie fallback", err)
|
|
|
|
fallbackSessionData, err := h.GetSessionFallback(r)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting fallback session: %w", err)
|
|
}
|
|
|
|
return fallbackSessionData, nil
|
|
}
|
|
|
|
func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Data, error) {
|
|
encryptedSessionData, err := h.Sessions.Read(ctx, sessionID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading session data from store: %w", err)
|
|
}
|
|
|
|
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decrypting session data: %w", err)
|
|
}
|
|
|
|
return sessionData, nil
|
|
}
|
|
|
|
func (h *Handler) getSessionLifetime(accessToken *jwt.AccessToken) time.Duration {
|
|
defaultSessionLifetime := h.Config.SessionMaxLifetime
|
|
|
|
tokenDuration := accessToken.GetExpiration().Sub(time.Now())
|
|
|
|
if tokenDuration <= defaultSessionLifetime {
|
|
return tokenDuration
|
|
}
|
|
|
|
return defaultSessionLifetime
|
|
}
|
|
|
|
func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *jwt.Tokens, params url.Values) error {
|
|
externalSessionID, err := NewSessionID(h.Provider.GetOpenIDConfiguration(), tokens.IDToken, params)
|
|
if err != nil {
|
|
return fmt.Errorf("generating session ID: %w", err)
|
|
}
|
|
|
|
sessionLifetime := h.getSessionLifetime(tokens.AccessToken)
|
|
opts := h.CookieOptions.WithExpiresIn(sessionLifetime)
|
|
|
|
sessionID := h.localSessionID(externalSessionID)
|
|
err = h.setEncryptedCookie(w, SessionCookieName, sessionID, opts)
|
|
if err != nil {
|
|
return fmt.Errorf("setting session cookie: %w", err)
|
|
}
|
|
|
|
sessionData := session.NewData(externalSessionID, tokens)
|
|
|
|
encryptedSessionData, err := sessionData.Encrypt(h.Crypter)
|
|
if err != nil {
|
|
return fmt.Errorf("encrypting session data: %w", err)
|
|
}
|
|
|
|
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
|
|
if err == nil {
|
|
h.DeleteSessionFallback(w, r)
|
|
return nil
|
|
}
|
|
|
|
log.Warnf("create session: store is unavailable: %+v; using cookie fallback", err)
|
|
|
|
err = h.SetSessionFallback(w, sessionData, sessionLifetime)
|
|
if err != nil {
|
|
return fmt.Errorf("writing session to fallback store: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, sessionID string) error {
|
|
err := h.Sessions.Delete(r.Context(), sessionID)
|
|
if err != nil {
|
|
return fmt.Errorf("deleting session from store: %w", err)
|
|
}
|
|
|
|
h.DeleteSessionFallback(w, r)
|
|
return nil
|
|
}
|