package session import ( "context" "errors" "fmt" "net/http" "net/url" "time" "github.com/sethvargo/go-retry" "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/crypto" mw "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" openidclient "github.com/nais/wonderwall/pkg/openid/client" openidconfig "github.com/nais/wonderwall/pkg/openid/config" retrypkg "github.com/nais/wonderwall/pkg/retry" "github.com/nais/wonderwall/pkg/strings" ) var ( CookieNotFoundError = errors.New("cookie not found") NoSessionDataError = errors.New("no session data") NoAccessTokenError = errors.New("no access token in session data") ExpiredAccessTokenError = errors.New("access token is expired") ) type Handler struct { client openidclient.Client crypter crypto.Crypter openidCfg openidconfig.Config refreshEnabled bool metadataRolloutEnabled bool store Store } func NewHandler(cfg *config.Config, openidCfg openidconfig.Config, crypter crypto.Crypter, openidClient openidclient.Client) (*Handler, error) { store, err := NewStore(cfg) if err != nil { return nil, err } return &Handler{ crypter: crypter, client: openidClient, openidCfg: openidCfg, store: store, refreshEnabled: cfg.Session.Refresh, metadataRolloutEnabled: cfg.Session.MetadataRollout, }, nil } // Create creates and stores a session in the Store, and returns the session's key. func (h *Handler) Create(r *http.Request, tokens *openid.Tokens, sessionLifetime time.Duration) (string, error) { externalSessionID, err := h.IDOrGenerate(r, tokens) if err != nil { return "", fmt.Errorf("generating session ID: %w", err) } key := h.Key(externalSessionID) tokenExpiresIn := tokens.Expiry.Sub(time.Now()) metadata := NewMetadata(tokenExpiresIn, sessionLifetime) encrypted, err := NewData(externalSessionID, tokens, metadata).Encrypt(h.crypter) if err != nil { return "", fmt.Errorf("encrypting session data: %w", err) } retryable := func(ctx context.Context) error { err = h.store.Write(r.Context(), key, encrypted, sessionLifetime) return retry.RetryableError(err) } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { return "", fmt.Errorf("writing to store: %w", err) } return key, nil } // DestroyForID destroys a session for a given session ID. Note that a session ID is not equal to a session Key. func (h *Handler) DestroyForID(r *http.Request, id string) error { key := h.Key(id) return h.destroyForKey(r, key) } func (h *Handler) destroyForKey(r *http.Request, key string) error { retryable := func(ctx context.Context) error { err := h.store.Delete(r.Context(), key) if err == nil { return nil } if errors.Is(err, KeyNotFoundError) { return err } return retry.RetryableError(err) } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { return fmt.Errorf("deleting from store: %w", err) } return nil } // Get returns the session data for a given http.Request, matching by the session cookie. func (h *Handler) Get(r *http.Request) (*Data, error) { key, err := h.GetKey(r) if err != nil { return nil, fmt.Errorf("no session cookie: %w", err) } return h.GetForKey(r, key) } // GetAccessToken returns an access token from the session. If the token is empty or expired, an error is returned. func (h *Handler) GetAccessToken(r *http.Request) (string, error) { sessionData, err := h.GetOrRefresh(r) if err != nil { return "", err } if sessionData == nil { return "", NoSessionDataError } if !sessionData.HasAccessToken() { return "", NoAccessTokenError } if h.metadataRolloutEnabled && sessionData.Metadata.IsExpired() { return "", ExpiredAccessTokenError } return sessionData.AccessToken, nil } // GetForID returns the session data for a given session ID. func (h *Handler) GetForID(r *http.Request, id string) (*Data, error) { key := h.Key(id) return h.GetForKey(r, key) } // GetForKey returns the session data for a given session Key. func (h *Handler) GetForKey(r *http.Request, key string) (*Data, error) { var encryptedSessionData *EncryptedData var err error retryable := func(ctx context.Context) error { encryptedSessionData, err = h.store.Read(ctx, key) if err == nil { return nil } if errors.Is(err, KeyNotFoundError) { return err } return retry.RetryableError(err) } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { return nil, fmt.Errorf("reading from store: %w", err) } sessionData, err := encryptedSessionData.Decrypt(h.crypter) if err != nil { return nil, fmt.Errorf("decrypting session data: %w", err) } return sessionData, nil } // GetKey extracts the session Key from the session cookie found in the request, if any. func (h *Handler) GetKey(r *http.Request) (string, error) { key, err := cookie.GetDecrypted(r, cookie.Session, h.crypter) if err != nil { return "", fmt.Errorf("%w: %+v", CookieNotFoundError, err) } return key, nil } // GetOrRefresh returns the session data, performing refreshes if enabled and necessary. func (h *Handler) GetOrRefresh(r *http.Request) (*Data, error) { key, err := h.GetKey(r) if err != nil { return nil, err } sessionData, err := h.GetForKey(r, key) if err != nil { return nil, err } if !h.shouldRefresh(sessionData) { return sessionData, nil } refreshed, err := h.Refresh(r, key, sessionData) if err != nil { mw.LogEntryFrom(r).Warnf("session: could not refresh tokens, falling back to existing token: %+v", err) } else { sessionData = refreshed } return sessionData, nil } // IDOrGenerate returns the session ID, derived from the given request or id_token; e.g. `sid` or `session_state`. // If none are present, a generated ID is returned. func (h *Handler) IDOrGenerate(r *http.Request, tokens *openid.Tokens) (string, error) { return NewSessionID(h.openidCfg.Provider(), tokens.IDToken, r.URL.Query()) } // Key prefixes the session ID, e.g. the `sid` or the `session_state` properties from the OpenID provider to prevent key // collisions in the session Store. // // `sid` or `session_state` is a key that refers to the user's unique SSO session at the OpenID Provider. // The same key is present in all tokens acquired by any Relying Party during that session. Thus, we cannot assume that // the value of `sid` or `session_state` to uniquely identify the pair of (user, application session) if using a shared // session store across multiple Relying Parties. func (h *Handler) Key(sessionID string) string { provider := h.openidCfg.Provider() client := h.openidCfg.Client() return fmt.Sprintf("%s:%s:%s", provider.Name(), client.ClientID(), sessionID) } // Refresh refreshes the user's session and returns the updated session data. func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error) { if !h.canRefresh(data) { return data, nil } logger := mw.LogEntryFrom(r) logger.Info("session: refreshing token...") var resp *openid.TokenResponse var err error refresh := func(ctx context.Context) error { resp, err = h.client.RefreshGrant(r.Context(), data.RefreshToken) return retry.RetryableError(err) } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, refresh); err != nil { return nil, fmt.Errorf("performing refresh: %w", err) } data.AccessToken = resp.AccessToken data.RefreshToken = resp.RefreshToken data.Metadata.Refresh(resp.ExpiresIn) encrypted, err := data.Encrypt(h.crypter) if err != nil { return nil, fmt.Errorf("encrypting session data: %w", err) } update := func(ctx context.Context) error { err = h.store.Update(r.Context(), key, encrypted) return retry.RetryableError(err) } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, update); err != nil { return nil, fmt.Errorf("updating in store: %w", err) } logger.Info("session: successfully refreshed") return data, nil } func (h *Handler) canRefresh(data *Data) bool { return h.refreshEnabled && data.HasRefreshToken() && !data.Metadata.IsRefreshOnCooldown() } func (h *Handler) shouldRefresh(data *Data) bool { return h.refreshEnabled && data.HasRefreshToken() && data.Metadata.ShouldRefresh() } func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url.Values) (string, error) { // 1. check for 'sid' claim in id_token sessionID, err := idToken.GetSidClaim() if err == nil { return sessionID, nil } // 1a. error if sid claim is required according to openid config if err != nil && cfg.SidClaimRequired() { return "", err } // 2. check for session_state in callback params sessionID, err = getSessionStateFrom(params) if err == nil { return sessionID, nil } // 2a. error if session_state is required according to openid config if err != nil && cfg.SessionStateRequired() { return "", err } // 3. generate ID if all else fails sessionID, err = strings.GenerateBase64(64) if err != nil { return "", fmt.Errorf("generating session ID: %w", err) } return sessionID, nil } func getSessionStateFrom(params url.Values) (string, error) { sessionState := params.Get(openid.SessionState) if len(sessionState) == 0 { return "", fmt.Errorf("missing required '%s' in params", openid.SessionState) } return sessionState, nil }