mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-07 00:46:56 +00:00
398 lines
11 KiB
Go
398 lines
11 KiB
Go
package session
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/sethvargo/go-retry"
|
|
|
|
"github.com/nais/wonderwall/pkg/config"
|
|
"github.com/nais/wonderwall/pkg/cookie"
|
|
"github.com/nais/wonderwall/pkg/crypto"
|
|
mw "github.com/nais/wonderwall/pkg/middleware"
|
|
"github.com/nais/wonderwall/pkg/openid"
|
|
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
|
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
|
retrypkg "github.com/nais/wonderwall/pkg/retry"
|
|
"github.com/nais/wonderwall/pkg/strings"
|
|
)
|
|
|
|
var (
	// ErrCookieNotFound is returned by GetKey when the request carries no session cookie.
	ErrCookieNotFound = errors.New("session cookie not found")

	// ErrInvalidSession is wrapped by errors describing a session that cannot be used,
	// e.g. missing data, a missing or expired access token, or an inactivity timeout.
	ErrInvalidSession = errors.New("invalid session")

	// ErrInvalidIdpState is wrapped when the identity provider rejects a refresh grant
	// with a client error, meaning the authorization might no longer be valid.
	ErrInvalidIdpState = errors.New("invalid state at idp")
)
|
|
|
|
// Timing parameters for the per-session lock taken during Refresh.
const (
	// refreshAcquireLockRetryInterval is how often Refresh retries acquiring the lock.
	refreshAcquireLockRetryInterval = 10 * time.Millisecond
	// refreshAcquireLockTimeout bounds the total time Refresh waits for the lock.
	refreshAcquireLockTimeout = 15 * time.Second
	// refreshLockDuration is the duration passed to Lock.Acquire once the lock is obtained.
	refreshLockDuration = 10 * time.Second
)
|
|
|
|
type Handler struct {
|
|
cfg *config.Config
|
|
client *openidclient.Client
|
|
crypter crypto.Crypter
|
|
openidCfg openidconfig.Config
|
|
store Store
|
|
}
|
|
|
|
func NewHandler(cfg *config.Config, openidCfg openidconfig.Config, crypter crypto.Crypter, openidClient *openidclient.Client) (*Handler, error) {
|
|
store, err := NewStore(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Handler{
|
|
crypter: crypter,
|
|
client: openidClient,
|
|
openidCfg: openidCfg,
|
|
store: store,
|
|
cfg: cfg,
|
|
}, nil
|
|
}
|
|
|
|
// Create creates and stores a session in the Store, and returns the session's key.
|
|
func (h *Handler) Create(r *http.Request, tokens *openid.Tokens, sessionLifetime time.Duration) (string, error) {
|
|
externalSessionID, err := h.IDOrGenerate(r, tokens)
|
|
if err != nil {
|
|
return "", fmt.Errorf("generating session ID: %w", err)
|
|
}
|
|
|
|
key := h.Key(externalSessionID)
|
|
tokenExpiresIn := time.Until(tokens.Expiry)
|
|
metadata := NewMetadata(tokenExpiresIn, sessionLifetime)
|
|
|
|
if h.cfg.Session.Inactivity {
|
|
metadata.WithTimeout(h.cfg.Session.InactivityTimeout)
|
|
}
|
|
|
|
encrypted, err := NewData(externalSessionID, tokens, metadata).Encrypt(h.crypter)
|
|
if err != nil {
|
|
return "", fmt.Errorf("encrypting session data: %w", err)
|
|
}
|
|
|
|
retryable := func(ctx context.Context) error {
|
|
err = h.store.Write(r.Context(), key, encrypted, sessionLifetime)
|
|
return retry.RetryableError(err)
|
|
}
|
|
|
|
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
|
return "", fmt.Errorf("writing to store: %w", err)
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
// Destroy destroys a session for a given session Key.
|
|
func (h *Handler) Destroy(r *http.Request, key string) error {
|
|
retryable := func(ctx context.Context) error {
|
|
err := h.store.Delete(r.Context(), key)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if errors.Is(err, ErrKeyNotFound) {
|
|
return err
|
|
}
|
|
|
|
return retry.RetryableError(err)
|
|
}
|
|
|
|
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
|
return fmt.Errorf("deleting from store: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetAccessToken returns an access token from the session. If the token is empty or expired, an error is returned.
|
|
func (h *Handler) GetAccessToken(r *http.Request) (string, error) {
|
|
sessionData, err := h.GetOrRefresh(r)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if sessionData == nil {
|
|
return "", fmt.Errorf("%w: no session data", ErrInvalidSession)
|
|
}
|
|
|
|
if !sessionData.HasAccessToken() {
|
|
return "", fmt.Errorf("%w: no access token in session data", ErrInvalidSession)
|
|
}
|
|
|
|
if sessionData.Metadata.IsExpired() {
|
|
return "", fmt.Errorf("%w: access token is expired", ErrInvalidSession)
|
|
}
|
|
|
|
return sessionData.AccessToken, nil
|
|
}
|
|
|
|
// Get returns the session data for a given session Key.
|
|
func (h *Handler) Get(r *http.Request, key string) (*Data, error) {
|
|
var encryptedSessionData *EncryptedData
|
|
var err error
|
|
|
|
retryable := func(ctx context.Context) error {
|
|
encryptedSessionData, err = h.store.Read(ctx, key)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if errors.Is(err, ErrKeyNotFound) {
|
|
return err
|
|
}
|
|
|
|
return retry.RetryableError(err)
|
|
}
|
|
|
|
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
|
return nil, fmt.Errorf("reading from store: %w", err)
|
|
}
|
|
|
|
sessionData, err := encryptedSessionData.Decrypt(h.crypter)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("decrypting session data: %w", err)
|
|
}
|
|
|
|
return sessionData, nil
|
|
}
|
|
|
|
// GetKey extracts the session Key from the session cookie found in the request, if any.
|
|
func (h *Handler) GetKey(r *http.Request) (string, error) {
|
|
key, err := cookie.GetDecrypted(r, cookie.Session, h.crypter)
|
|
if errors.Is(err, http.ErrNoCookie) {
|
|
return "", ErrCookieNotFound
|
|
}
|
|
if errors.Is(err, cookie.ErrInvalidValue) {
|
|
return "", err
|
|
}
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return key, nil
|
|
}
|
|
|
|
// GetOrRefresh returns the session data, performing refreshes if enabled and necessary.
|
|
func (h *Handler) GetOrRefresh(r *http.Request) (*Data, error) {
|
|
key, err := h.GetKey(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
sessionData, err := h.Get(r, key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if h.isTimedOut(sessionData) {
|
|
return nil, fmt.Errorf("%w: session is inactive", ErrInvalidSession)
|
|
}
|
|
|
|
if !h.shouldRefresh(sessionData) {
|
|
return sessionData, nil
|
|
}
|
|
|
|
refreshed, err := h.Refresh(r, key, sessionData)
|
|
if errors.Is(err, ErrInvalidIdpState) || errors.Is(err, ErrInvalidSession) {
|
|
return nil, err
|
|
} else if err != nil {
|
|
mw.LogEntryFrom(r).Warnf("session: could not refresh tokens; falling back to existing token: %+v", err)
|
|
} else {
|
|
sessionData = refreshed
|
|
}
|
|
|
|
return sessionData, nil
|
|
}
|
|
|
|
// IDOrGenerate returns the session ID, derived from the given request or id_token; e.g. `sid` or `session_state`.
|
|
// If none are present, a generated ID is returned.
|
|
func (h *Handler) IDOrGenerate(r *http.Request, tokens *openid.Tokens) (string, error) {
|
|
return NewSessionID(h.openidCfg.Provider(), tokens.IDToken, r.URL.Query())
|
|
}
|
|
|
|
// Key prefixes the session ID, e.g. the `sid` or the `session_state` properties from the OpenID provider to prevent key
|
|
// collisions in the session Store.
|
|
//
|
|
// `sid` or `session_state` is a key that refers to the user's unique SSO session at the OpenID Provider.
|
|
// The same key is present in all tokens acquired by any Relying Party during that session. Thus, we cannot assume that
|
|
// the value of `sid` or `session_state` to uniquely identify the pair of (user, application session) if using a shared
|
|
// session store across multiple Relying Parties.
|
|
func (h *Handler) Key(sessionID string) string {
|
|
client := h.openidCfg.Client()
|
|
|
|
return fmt.Sprintf("%s:%s:%s", h.cfg.OpenID.Provider, client.ClientID(), sessionID)
|
|
}
|
|
|
|
// Refresh refreshes the user's session and returns the updated session data.
|
|
func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error) {
|
|
if !h.canRefresh(data) {
|
|
return data, nil
|
|
}
|
|
|
|
logger := mw.LogEntryFrom(r)
|
|
logger.Debug("session: initiating refresh attempt...")
|
|
|
|
ctx := r.Context()
|
|
lock := h.store.MakeLock(key)
|
|
|
|
logger.Debug("session: acquiring lock...")
|
|
err := func() error {
|
|
timeout := time.NewTimer(refreshAcquireLockTimeout)
|
|
defer timeout.Stop()
|
|
|
|
ticker := time.NewTicker(refreshAcquireLockRetryInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("context done: %w", ctx.Err())
|
|
case <-timeout.C:
|
|
return fmt.Errorf("timed out")
|
|
case <-ticker.C:
|
|
err := lock.Acquire(ctx, refreshLockDuration)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if !errors.Is(err, ErrAcquireLock) {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("while acquiring lock: %w", err)
|
|
}
|
|
defer func(lock Lock, ctx context.Context) {
|
|
err := lock.Release(ctx)
|
|
if err != nil {
|
|
logger.Warnf("session: releasing lock: %+v", err)
|
|
}
|
|
}(lock, ctx)
|
|
|
|
// Get the latest session state again in case it was changed while acquiring the lock
|
|
data, err = h.Get(r, key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !h.canRefresh(data) {
|
|
logger.Debug("session: already refreshed, aborting refresh attempt.")
|
|
return data, nil
|
|
}
|
|
|
|
if h.isTimedOut(data) {
|
|
return nil, fmt.Errorf("%w: session is inactive", ErrInvalidSession)
|
|
}
|
|
|
|
logger.Debug("session: performing refresh grant...")
|
|
var resp *openid.TokenResponse
|
|
refresh := func(ctx context.Context) error {
|
|
resp, err = h.client.RefreshGrant(ctx, data.RefreshToken)
|
|
if errors.Is(err, openidclient.ErrOpenIDServer) {
|
|
return retry.RetryableError(err)
|
|
}
|
|
|
|
return err
|
|
}
|
|
if err := retry.Do(ctx, retrypkg.DefaultBackoff, refresh); err != nil {
|
|
if errors.Is(err, openidclient.ErrOpenIDClient) {
|
|
return nil, fmt.Errorf("%w: authorization might be invalid: %+v", ErrInvalidIdpState, err)
|
|
}
|
|
return nil, fmt.Errorf("performing refresh: %w", err)
|
|
}
|
|
|
|
data.AccessToken = resp.AccessToken
|
|
data.RefreshToken = resp.RefreshToken
|
|
data.Metadata.Refresh(resp.ExpiresIn)
|
|
|
|
if h.cfg.Session.Inactivity {
|
|
data.Metadata.ExtendTimeout(h.cfg.Session.InactivityTimeout)
|
|
}
|
|
|
|
err = h.Update(ctx, key, data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info("session: successfully refreshed")
|
|
return data, nil
|
|
}
|
|
|
|
func (h *Handler) Update(ctx context.Context, key string, data *Data) error {
|
|
encrypted, err := data.Encrypt(h.crypter)
|
|
if err != nil {
|
|
return fmt.Errorf("encrypting session data: %w", err)
|
|
}
|
|
|
|
update := func(ctx context.Context) error {
|
|
err = h.store.Update(ctx, key, encrypted)
|
|
if errors.Is(err, ErrKeyNotFound) {
|
|
return err
|
|
}
|
|
return retry.RetryableError(err)
|
|
}
|
|
|
|
if err := retry.Do(ctx, retrypkg.DefaultBackoff, update); err != nil {
|
|
return fmt.Errorf("updating in store: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (h *Handler) canRefresh(data *Data) bool {
|
|
return h.cfg.Session.Refresh && data.HasRefreshToken() && !data.Metadata.IsRefreshOnCooldown()
|
|
}
|
|
|
|
func (h *Handler) shouldRefresh(data *Data) bool {
|
|
return h.cfg.Session.Refresh && data.HasRefreshToken() && data.Metadata.ShouldRefresh()
|
|
}
|
|
|
|
func (h *Handler) isTimedOut(data *Data) bool {
|
|
return h.cfg.Session.Inactivity && data.Metadata.IsTimedOut()
|
|
}
|
|
|
|
func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url.Values) (string, error) {
|
|
// 1. check for 'sid' claim in id_token
|
|
sessionID, err := idToken.GetSidClaim()
|
|
if err == nil {
|
|
return sessionID, nil
|
|
}
|
|
// 1a. error if sid claim is required according to openid config
|
|
if err != nil && cfg.SidClaimRequired() {
|
|
return "", err
|
|
}
|
|
|
|
// 2. check for session_state in callback params
|
|
sessionID, err = getSessionStateFrom(params)
|
|
if err == nil {
|
|
return sessionID, nil
|
|
}
|
|
// 2a. error if session_state is required according to openid config
|
|
if err != nil && cfg.SessionStateRequired() {
|
|
return "", err
|
|
}
|
|
|
|
// 3. generate ID if all else fails
|
|
sessionID, err = strings.GenerateBase64(64)
|
|
if err != nil {
|
|
return "", fmt.Errorf("generating session ID: %w", err)
|
|
}
|
|
return sessionID, nil
|
|
}
|
|
|
|
func getSessionStateFrom(params url.Values) (string, error) {
|
|
sessionState := params.Get(openid.SessionState)
|
|
if len(sessionState) == 0 {
|
|
return "", fmt.Errorf("missing required '%s' in params", openid.SessionState)
|
|
}
|
|
return sessionState, nil
|
|
}
|