mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-10 10:27:02 +00:00
These feature flags were enabled by default. We specifically disallowed the use of automatic refresh in SSO mode, though this adds some complexity when using the forward-auth feature. To simplify configuration and code, we remove the flags entirely, as session refresh behaviour is already mostly handled by the implementation of GetSession() in the handlers. Specifically:
- The Standalone handler needs to refresh sessions when reverse-proxying to the upstream.
- The SSO server handler needs to refresh sessions only when using the forward-auth feature; it does not have an upstream to reverse-proxy to.
- The SSO proxy handler is a read-only upstream proxy and cannot refresh sessions itself, though it delegates traffic for the session endpoints to the configured SSO server.
Automatic refreshing is thus only disabled when running in SSO mode without the forward-auth feature.
264 lines
6.6 KiB
Go
264 lines
6.6 KiB
Go
package session
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/nais/wonderwall/pkg/config"
|
|
"github.com/nais/wonderwall/pkg/crypto"
|
|
mw "github.com/nais/wonderwall/pkg/middleware"
|
|
"github.com/nais/wonderwall/pkg/openid"
|
|
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
|
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
|
"github.com/nais/wonderwall/pkg/retry"
|
|
)
|
|
|
|
const (
	// refreshAcquireLockRetryInterval is the tick interval between successive
	// attempts at acquiring the refresh lock.
	refreshAcquireLockRetryInterval = 10 * time.Millisecond
	// refreshAcquireLockTimeout bounds the total time spent attempting to
	// acquire the refresh lock before giving up.
	refreshAcquireLockTimeout = 15 * time.Second
	// refreshLockDuration is the duration passed to Lock.Acquire when taking
	// the refresh lock.
	refreshLockDuration = 10 * time.Second
)
|
|
|
|
// Compile-time check that *manager satisfies the Manager interface.
var _ Manager = (*manager)(nil)
|
|
|
|
// manager is the Manager implementation in this file. It embeds a reader and
// additionally holds the configuration, the OpenID client used for refresh
// grants, the OpenID configuration, and the session store.
type manager struct {
	*reader
	cfg       *config.Config
	client    *openidclient.Client
	openidCfg openidconfig.Config
	store     Store
}
|
|
|
|
func NewManager(cfg *config.Config, openidCfg openidconfig.Config, crypter crypto.Crypter, openidClient *openidclient.Client) (Manager, error) {
|
|
store, err := NewStore(cfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rd := &reader{
|
|
cfg: cfg,
|
|
cookieCrypter: crypter,
|
|
store: store,
|
|
}
|
|
|
|
return &manager{
|
|
reader: rd,
|
|
cfg: cfg,
|
|
client: openidClient,
|
|
openidCfg: openidCfg,
|
|
store: store,
|
|
}, nil
|
|
}
|
|
|
|
func (in *manager) Create(r *http.Request, tokens *openid.Tokens, sessionLifetime time.Duration) (*Session, error) {
|
|
externalSessionID, err := ExternalID(r, in.openidCfg.Provider(), tokens.IDToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating session ID: %w", err)
|
|
}
|
|
|
|
key := in.key(externalSessionID)
|
|
tokenExpiresIn := time.Until(tokens.Expiry)
|
|
metadata := NewMetadata(tokenExpiresIn, sessionLifetime)
|
|
|
|
if in.cfg.Session.Inactivity {
|
|
metadata.WithTimeout(in.cfg.Session.InactivityTimeout)
|
|
}
|
|
|
|
ticket, err := NewTicket(key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("making ticket: %w", err)
|
|
}
|
|
|
|
data := NewData(externalSessionID, tokens, metadata)
|
|
|
|
encrypted, err := data.Encrypt(ticket.Crypter())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("encrypting session data: %w", err)
|
|
}
|
|
|
|
if err := retry.Do(r.Context(), func(ctx context.Context) error {
|
|
err = in.store.Write(r.Context(), key, encrypted, sessionLifetime)
|
|
return retry.RetryableError(err)
|
|
}); err != nil {
|
|
return nil, fmt.Errorf("writing to store: %w", err)
|
|
}
|
|
|
|
return NewSession(data, ticket), nil
|
|
}
|
|
|
|
func (in *manager) Delete(ctx context.Context, session *Session) error {
|
|
return in.deleteForKey(ctx, session.key())
|
|
}
|
|
|
|
func (in *manager) DeleteForExternalID(ctx context.Context, id string) error {
|
|
key := in.key(id)
|
|
return in.deleteForKey(ctx, key)
|
|
}
|
|
|
|
func (in *manager) GetOrRefresh(r *http.Request) (*Session, error) {
|
|
sess, err := in.Get(r)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting session: %w", err)
|
|
}
|
|
|
|
if !sess.shouldRefresh() {
|
|
return sess, nil
|
|
}
|
|
|
|
refreshed, err := in.Refresh(r, sess)
|
|
if err == nil {
|
|
return refreshed, nil
|
|
}
|
|
|
|
if errors.Is(err, ErrInvalidExternal) || errors.Is(err, ErrInvalid) {
|
|
return nil, err
|
|
}
|
|
|
|
if !errors.Is(err, context.Canceled) {
|
|
mw.LogEntryFrom(r).Warnf("session: could not refresh tokens; falling back to existing tokens: %+v", err)
|
|
}
|
|
|
|
return sess, nil
|
|
}
|
|
|
|
func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) {
|
|
if !sess.canRefresh() {
|
|
return sess, nil
|
|
}
|
|
|
|
logger := mw.LogEntryFrom(r).WithField("sid", sess.ExternalSessionID())
|
|
logger.Debug("session: initiating refresh attempt...")
|
|
|
|
ctx := r.Context()
|
|
lock := in.store.MakeLock(sess.ticket.Key())
|
|
|
|
logger.Debug("session: acquiring lock...")
|
|
err := func() error {
|
|
timeout := time.NewTimer(refreshAcquireLockTimeout)
|
|
defer timeout.Stop()
|
|
|
|
ticker := time.NewTicker(refreshAcquireLockRetryInterval)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return fmt.Errorf("context done: %w", ctx.Err())
|
|
case <-timeout.C:
|
|
return fmt.Errorf("timed out")
|
|
case <-ticker.C:
|
|
err := lock.Acquire(ctx, refreshLockDuration)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if !errors.Is(err, ErrAcquireLock) {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("while acquiring lock: %w", err)
|
|
}
|
|
defer func(lock Lock, ctx context.Context) {
|
|
err := lock.Release(ctx)
|
|
if err != nil && !errors.Is(err, context.Canceled) {
|
|
logger.Warnf("session: releasing lock: %+v", err)
|
|
}
|
|
}(lock, ctx)
|
|
|
|
// Get the latest session state again in case it was changed while acquiring the lock
|
|
sess, err = in.getForTicket(ctx, sess.ticket)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !sess.canRefresh() {
|
|
logger.Debug("session: already refreshed, aborting refresh attempt.")
|
|
return sess, nil
|
|
}
|
|
|
|
logger.Debug("session: performing refresh grant...")
|
|
resp, err := retry.DoValue(ctx, func(ctx context.Context) (*openid.TokenResponse, error) {
|
|
resp, err := in.client.RefreshGrant(ctx, sess.data.RefreshToken)
|
|
if errors.Is(err, openidclient.ErrOpenIDServer) {
|
|
return nil, retry.RetryableError(err)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return resp, nil
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, openidclient.ErrOpenIDClient) {
|
|
return nil, fmt.Errorf("%w: authorization might be invalid: %+v", ErrInvalidExternal, err)
|
|
}
|
|
return nil, fmt.Errorf("performing refresh: %w", err)
|
|
}
|
|
|
|
sess.data.AccessToken = resp.AccessToken
|
|
sess.data.RefreshToken = resp.RefreshToken
|
|
sess.data.Metadata.Refresh(resp.ExpiresIn)
|
|
|
|
if in.cfg.Session.Inactivity {
|
|
sess.data.Metadata.WithTimeout(in.cfg.Session.InactivityTimeout)
|
|
}
|
|
|
|
err = in.update(ctx, sess)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logger.Info("session: successfully refreshed")
|
|
return sess, nil
|
|
}
|
|
|
|
func (in *manager) deleteForKey(ctx context.Context, key string) error {
|
|
if err := retry.Do(ctx, func(ctx context.Context) error {
|
|
err := in.store.Delete(ctx, key)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
if errors.Is(err, ErrNotFound) {
|
|
return err
|
|
}
|
|
|
|
return retry.RetryableError(err)
|
|
}); err != nil {
|
|
return fmt.Errorf("deleting from store: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// key constructs a session key given an external session ID, e.g. the `sid` or the `session_state` properties from the OpenID Connect auth code flow.
|
|
func (in *manager) key(externalSessionID string) string {
|
|
clientID := in.openidCfg.Client().ClientID()
|
|
providerName := in.cfg.OpenID.Provider
|
|
return fmt.Sprintf("%s:%s:%s", providerName, clientID, externalSessionID)
|
|
}
|
|
|
|
func (in *manager) update(ctx context.Context, sess *Session) error {
|
|
encrypted, err := sess.encrypt()
|
|
if err != nil {
|
|
return fmt.Errorf("encrypting session data: %w", err)
|
|
}
|
|
|
|
if err := retry.Do(ctx, func(ctx context.Context) error {
|
|
err = in.store.Update(ctx, sess.ticket.Key(), encrypted)
|
|
if errors.Is(err, ErrNotFound) {
|
|
return err
|
|
}
|
|
return retry.RetryableError(err)
|
|
}); err != nil {
|
|
return fmt.Errorf("updating in store: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|