diff --git a/README.md b/README.md index c75e83e..940eef5 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,8 @@ The following flags are available: --redis.password string Password for Redis. --redis.tls Whether or not to use TLS for connecting to Redis. (default true) --redis.username string Username for Redis. +--session.inactivity Automatically expire user sessions if they have not refreshed their tokens within a given duration. +--session.inactivity-timeout duration Inactivity timeout for user sessions. (default 30m0s) --session.max-lifetime duration Max lifetime for user sessions. (default 1h0m0s) --session.refresh Automatically refresh the tokens for user sessions if they are expired, as long as the session exists (indicated by the session max lifetime). --upstream-host string Address of upstream host. (default "127.0.0.1:8080") @@ -177,7 +179,10 @@ Otherwise, an `HTTP 200 OK` is returned with the metadata with the `application/ "session": { "created_at": "2022-08-31T06:58:38.724717899Z", "ends_at": "2022-08-31T16:58:38.724717899Z", - "ends_in_seconds": 14658 + "timeout_at": "0001-01-01T00:00:00Z", + "ends_in_seconds": 14658, + "active": true, + "timeout_in_seconds": -1 }, "tokens": { "expire_at": "2022-08-31T14:03:47.318251953Z", @@ -189,14 +194,17 @@ Otherwise, an `HTTP 200 OK` is returned with the metadata with the `application/ Most of these fields should be self-explanatory, but we'll be explicit with their description: -| Field | Description | -|---------------------------------------|-------------------------------------------------------------------------------------------------| -| `session.created_at` | The timestamp that denotes when the session was first created. | -| `session.ends_at` | The timestamp that denotes when the session will end. | -| `session.ends_in_seconds` | The number of seconds until the session ends. | -| `tokens.expire_at` | The timestamp that denotes when the tokens within the session will expire. | -| `tokens.refreshed_at` | The timestamp that denotes when the tokens within the session was last refreshed. | -| `tokens.expire_in_seconds` | The number of seconds until the tokens expire. | +| Field | Description | +|------------------------------|----------------------------------------------------------------------------------------------------------------------| +| `session.created_at` | The timestamp that denotes when the session was first created. | +| `session.ends_at` | The timestamp that denotes when the session will end. | +| `session.timeout_at` | The timestamp that denotes when the session will time out. The zero-value, `0001-01-01T00:00:00Z`, means no timeout. | +| `session.ends_in_seconds` | The number of seconds until the session ends. | +| `session.active` | Whether or not the session is marked as active. | +| `session.timeout_in_seconds` | The number of seconds until the session times out. A value of `-1` means no timeout. | +| `tokens.expire_at` | The timestamp that denotes when the tokens within the session will expire. | +| `tokens.refreshed_at` | The timestamp that denotes when the tokens within the session was last refreshed. | +| `tokens.expire_in_seconds` | The number of seconds until the tokens expire. | ### Refresh Tokens @@ -243,6 +251,20 @@ Note that the refresh operation has a default cooldown period of 1 minute, which of the tokens returned by the identity provider. In other words, a request to the `/oauth2/session/refresh` endpoint will only trigger a refresh if `tokens.refresh_cooldown` is `false`. +### Inactivity + +A session can be marked as inactive if the time since last refresh exceeds a given timeout. This is useful if you want +to ensure that an end-user can re-authenticate with the identity provider if they've been gone from an authenticated +session for some time. + +This is enabled with the `session.inactivity` option, which also requires `session.refresh`. + +The `/oauth2/session` endpoint returns `session.active`, `session.timeout_at` and `session.timeout_in_seconds` that +indicates the state of the session and when it times out. + +The timeout is configured with `session.inactivity-timeout`. If this timeout is shorter than the token lifetime, you +should implement mechanisms to trigger refreshes before the timeout is reached. + ## Development ### Requirements diff --git a/pkg/config/config.go b/pkg/config/config.go index bfe8b8c..5474133 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "time" "github.com/nais/liberator/pkg/conftools" @@ -40,8 +41,10 @@ type Loginstatus struct { } type Session struct { - MaxLifetime time.Duration `json:"max-lifetime"` - Refresh bool `json:"refresh"` + Inactivity bool `json:"inactivity"` + InactivityTimeout time.Duration `json:"inactivity-timeout"` + MaxLifetime time.Duration `json:"max-lifetime"` + Refresh bool `json:"refresh"` } const ( @@ -57,8 +60,10 @@ const ( Ingress = "ingress" UpstreamHost = "upstream-host" - SessionMaxLifetime = "session.max-lifetime" - SessionRefresh = "session.refresh" + SessionInactivity = "session.inactivity" + SessionInactivityTimeout = "session.inactivity-timeout" + SessionMaxLifetime = "session.max-lifetime" + SessionRefresh = "session.refresh" LoginstatusEnabled = "loginstatus.enabled" LoginstatusCookieDomain = "loginstatus.cookie-domain" @@ -82,6 +87,8 @@ func Initialize() (*Config, error) { flag.StringSlice(Ingress, []string{}, "Comma separated list of ingresses used to access the main application.") flag.String(UpstreamHost, "127.0.0.1:8080", "Address of upstream host.") + flag.Bool(SessionInactivity, false, "Automatically expire user sessions if they have not refreshed their tokens within a given duration.") + flag.Duration(SessionInactivityTimeout, 30*time.Minute, "Inactivity timeout for user sessions.") flag.Duration(SessionMaxLifetime, time.Hour, "Max lifetime for user sessions.") flag.Bool(SessionRefresh, false, "Automatically refresh the tokens for user sessions if they are expired, as long as the session exists (indicated by the session max lifetime).") @@ -137,5 +144,18 @@ func Initialize() (*Config, error) { log.WithField("logger", "wonderwall.config").Info(line) } + err := cfg.Validate() + if err != nil { + return nil, fmt.Errorf("validating config: %w", err) + } + return cfg, nil } + +func (c *Config) Validate() error { + if c.Session.Inactivity && !c.Session.Refresh { + return fmt.Errorf("%q cannot be enabled without %q", SessionInactivity, SessionRefresh) + } + + return nil +} diff --git a/pkg/handler/api/session/session.go b/pkg/handler/api/session/session.go index 67556a3..d682971 100644 --- a/pkg/handler/api/session/session.go +++ b/pkg/handler/api/session/session.go @@ -20,15 +20,18 @@ func Handler(src Source, w http.ResponseWriter, r *http.Request) { data, err := src.GetSessions().Get(r) if err != nil { - if errors.Is(err, session.ErrCookieNotFound) || errors.Is(err, session.ErrKeyNotFound) { + switch { + case errors.Is(err, session.ErrCookieNotFound), errors.Is(err, session.ErrKeyNotFound): logger.Infof("session/info: getting session: %+v", err) w.WriteHeader(http.StatusUnauthorized) return + case errors.Is(err, session.ErrSessionInactive): + // do nothing; we want to return metadata even if the session is inactive + default: + logger.Warnf("session/info: getting session: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return } - - logger.Warnf("session/info: getting session: %+v", err) - w.WriteHeader(http.StatusInternalServerError) - return } w.Header().Set("Content-Type", "application/json") diff --git a/pkg/handler/api/sessionrefresh/sessionrefresh.go b/pkg/handler/api/sessionrefresh/sessionrefresh.go index cf0becb..dcfe70a 100644 --- a/pkg/handler/api/sessionrefresh/sessionrefresh.go +++ b/pkg/handler/api/sessionrefresh/sessionrefresh.go @@ -25,14 +25,14 @@ func Handler(src Source, w http.ResponseWriter, r *http.Request) { data, err := src.GetSessions().Get(r) if err != nil { - if errors.Is(err, session.ErrKeyNotFound) { + switch { + case errors.Is(err, session.ErrKeyNotFound), errors.Is(err, session.ErrSessionInactive): logger.Infof("session/refresh: getting session: %+v", err) w.WriteHeader(http.StatusUnauthorized) - return + default: + logger.Warnf("session/refresh: getting session: %+v", err) + w.WriteHeader(http.StatusInternalServerError) } - - logger.Warnf("session/refresh: getting session: %+v", err) - w.WriteHeader(http.StatusInternalServerError) return } diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index 5858538..54e9de2 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -188,6 +188,41 @@ func TestHandler_SessionInfo(t *testing.T) { // 1 second < time until token expires <= max duration for tokens from IDP assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) assert.Greater(t, tokenExpiryDuration, time.Second) + + assert.True(t, data.Session.Active) + assert.True(t, data.Session.TimeoutAt.IsZero()) + assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) +} + +func TestHandler_SessionInfo_WithInactivity(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + cfg.Session.Inactivity = true + cfg.Session.InactivityTimeout = 10 * time.Minute + + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerbose + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + maxDelta := 5 * time.Second + + assert.True(t, data.Session.Active) + assert.False(t, data.Session.TimeoutAt.IsZero()) + + expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout) + assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta) + + actualTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second + assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(actualTimeoutDuration), maxDelta) } func TestHandler_SessionInfo_WithRefresh(t *testing.T) { @@ -232,6 +267,10 @@ func TestHandler_SessionInfo_WithRefresh(t *testing.T) { // 1 second < refresh cooldown <= minimum refresh interval assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval) assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1)) + + assert.True(t, data.Session.Active) + assert.True(t, data.Session.TimeoutAt.IsZero()) + assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) } func TestHandler_SessionRefresh(t *testing.T) { @@ -254,27 +293,7 @@ func TestHandler_SessionRefresh(t *testing.T) { assert.NoError(t, err) // wait until refresh cooldown has reached zero before refresh - func() { - timeout := time.After(5 * time.Second) - ticker := time.Tick(500 * time.Millisecond) - for { - select { - case <-timeout: - assert.Fail(t, "refresh cooldown timer exceeded timeout") - case <-ticker: - resp := sessionInfo(t, idp, rpClient) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - var temp session.MetadataVerboseWithRefresh - err = json.Unmarshal([]byte(resp.Body), &temp) - assert.NoError(t, err) - - if !temp.Tokens.RefreshCooldown { - return - } - } - } - }() + waitForRefreshCooldownTimer(t, idp, rpClient) resp = sessionRefresh(t, idp, rpClient) assert.Equal(t, http.StatusOK, resp.StatusCode) @@ -313,6 +332,15 @@ func TestHandler_SessionRefresh(t *testing.T) { // 1 second < refresh cooldown <= minimum refresh interval assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval) assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1)) + + assert.True(t, data.Session.Active) + assert.True(t, refreshedData.Session.Active) + + assert.True(t, data.Session.TimeoutAt.IsZero()) + assert.True(t, refreshedData.Session.TimeoutAt.IsZero()) + + assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) + assert.Equal(t, int64(-1), refreshedData.Session.TimeoutInSeconds) } func TestHandler_SessionRefresh_Disabled(t *testing.T) { @@ -330,6 +358,58 @@ func TestHandler_SessionRefresh_Disabled(t *testing.T) { assert.Equal(t, http.StatusNotFound, resp.StatusCode) } +func TestHandler_SessionRefresh_WithInactivity(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + cfg.Session.Inactivity = true + cfg.Session.InactivityTimeout = 10 * time.Minute + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Second + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + // get initial session info + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerboseWithRefresh + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + // wait until refresh cooldown has reached zero before refresh + waitForRefreshCooldownTimer(t, idp, rpClient) + + resp = sessionRefresh(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var refreshedData session.MetadataVerboseWithRefresh + err = json.Unmarshal([]byte(resp.Body), &refreshedData) + assert.NoError(t, err) + + maxDelta := 5 * time.Second + + assert.True(t, data.Session.Active) + assert.True(t, refreshedData.Session.Active) + + assert.False(t, data.Session.TimeoutAt.IsZero()) + assert.False(t, refreshedData.Session.TimeoutAt.IsZero()) + + expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout) + assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta) + assert.WithinDuration(t, expectedTimeoutAt, refreshedData.Session.TimeoutAt, maxDelta) + + assert.True(t, refreshedData.Session.TimeoutAt.After(data.Session.TimeoutAt)) + + previousTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second + assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(previousTimeoutDuration), maxDelta) + + refreshedTimeoutDuration := time.Duration(refreshedData.Session.TimeoutInSeconds) * time.Second + assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(refreshedTimeoutDuration), maxDelta) +} + func TestHandler_Default(t *testing.T) { up := newUpstream(t) defer up.Server.Close() @@ -689,6 +769,28 @@ func sessionRefresh(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Cli return get(t, rpClient, sessionRefreshURL.String()) } +func waitForRefreshCooldownTimer(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) { + timeout := time.After(5 * time.Second) + ticker := time.Tick(500 * time.Millisecond) + for { + select { + case <-timeout: + assert.Fail(t, "refresh cooldown timer exceeded timeout") + case <-ticker: + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var temp session.MetadataVerboseWithRefresh + err := json.Unmarshal([]byte(resp.Body), &temp) + assert.NoError(t, err) + + if !temp.Tokens.RefreshCooldown { + return + } + } + } +} + type response struct { Body string Location *url.URL diff --git a/pkg/session/data.go b/pkg/session/data.go index 37abcda..0f3d67d 100644 --- a/pkg/session/data.go +++ b/pkg/session/data.go @@ -109,6 +109,8 @@ type MetadataSession struct { CreatedAt time.Time `json:"created_at"` // EndsAt is the time when the session will end, i.e. the absolute lifetime/time-to-live for the session. EndsAt time.Time `json:"ends_at"` + // TimeoutAt is the time when the session will be marked as inactive. A zero value means no timeout. The timeout is extended whenever the tokens are refreshed. + TimeoutAt time.Time `json:"timeout_at"` } type MetadataTokens struct { @@ -118,7 +120,7 @@ type MetadataTokens struct { RefreshedAt time.Time `json:"refreshed_at"` } -func NewMetadata(expiresIn time.Duration, endsIn time.Duration) *Metadata { +func NewMetadata(expiresIn, endsIn time.Duration) *Metadata { now := time.Now() return &Metadata{ Session: MetadataSession{ @@ -182,22 +184,47 @@ func (in *Metadata) TokenLifetime() time.Duration { return in.Tokens.ExpireAt.Sub(in.Tokens.RefreshedAt) } +func (in *Metadata) ExtendTimeout(duration time.Duration) { + in.Session.TimeoutAt = time.Now().Add(duration) +} + +func (in *Metadata) IsTimedOut() bool { + if in.Session.TimeoutAt.IsZero() { + return false + } + + return time.Now().After(in.Session.TimeoutAt) +} + +func (in *Metadata) WithTimeout(timeoutIn time.Duration) { + in.Session.TimeoutAt = time.Now().Add(timeoutIn) +} + func (in *Metadata) Verbose() MetadataVerbose { now := time.Now() expireTime := in.Tokens.ExpireAt endTime := in.Session.EndsAt + timeoutTime := in.Session.TimeoutAt - return MetadataVerbose{ + mv := MetadataVerbose{ Session: MetadataSessionVerbose{ - MetadataSession: in.Session, - EndsInSeconds: toSeconds(endTime.Sub(now)), + MetadataSession: in.Session, + EndsInSeconds: toSeconds(endTime.Sub(now)), + Active: !in.IsTimedOut(), + TimeoutInSeconds: toSeconds(timeoutTime.Sub(now)), }, Tokens: MetadataTokensVerbose{ MetadataTokens: in.Tokens, ExpireInSeconds: toSeconds(expireTime.Sub(now)), }, } + + if timeoutTime.IsZero() { + mv.Session.TimeoutInSeconds = int64(-1) + } + + return mv } func (in *Metadata) VerboseWithRefresh() MetadataVerboseWithRefresh { @@ -229,7 +256,9 @@ type MetadataVerboseWithRefresh struct { type MetadataSessionVerbose struct { MetadataSession - EndsInSeconds int64 `json:"ends_in_seconds"` + EndsInSeconds int64 `json:"ends_in_seconds"` + Active bool `json:"active"` + TimeoutInSeconds int64 `json:"timeout_in_seconds"` } type MetadataTokensVerbose struct { diff --git a/pkg/session/data_test.go b/pkg/session/data_test.go index d4c1123..e68ac0f 100644 --- a/pkg/session/data_test.go +++ b/pkg/session/data_test.go @@ -25,6 +25,50 @@ func TestData_HasRefreshToken(t *testing.T) { assert.True(t, data.HasRefreshToken()) } +func TestNewMetadata(t *testing.T) { + tokenLifetime := 30 * time.Minute + sessionLifetime := time.Hour + + metadata := session.NewMetadata(tokenLifetime, sessionLifetime) + + maxDelta := time.Second + + expected := time.Now() + actual := metadata.Session.CreatedAt + assert.WithinDuration(t, expected, actual, maxDelta) + + expected = time.Now().Add(sessionLifetime) + actual = metadata.Session.EndsAt + assert.WithinDuration(t, expected, actual, maxDelta) + + assert.True(t, metadata.Session.TimeoutAt.IsZero()) + + expected = time.Now() + actual = metadata.Tokens.RefreshedAt + assert.WithinDuration(t, expected, actual, maxDelta) + + expected = time.Now().Add(tokenLifetime) + actual = metadata.Tokens.ExpireAt + assert.WithinDuration(t, expected, actual, maxDelta) +} + +func TestMetadata_WithTimeout(t *testing.T) { + tokenLifetime := 30 * time.Minute + sessionLifetime := time.Hour + sessionInactivityTimeout := 15 * time.Minute + + metadata := session.NewMetadata(tokenLifetime, sessionLifetime) + metadata.WithTimeout(sessionInactivityTimeout) + + maxDelta := time.Second + + assert.False(t, metadata.Session.TimeoutAt.IsZero()) + + expected := time.Now().Add(sessionInactivityTimeout) + actual := metadata.Session.TimeoutAt + assert.WithinDuration(t, expected, actual, maxDelta) +} + func TestMetadata_IsExpired(t *testing.T) { t.Run("expired", func(t *testing.T) { metadata := session.Metadata{ @@ -214,6 +258,10 @@ func TestMetadata_Verbose(t *testing.T) { expected = time.Now().Add(tokenLifetime) actual = time.Now().Add(durationSeconds(verbose.Tokens.ExpireInSeconds)) assert.WithinDuration(t, expected, actual, maxDelta) + + assert.True(t, verbose.Session.Active) + assert.True(t, verbose.Session.TimeoutAt.IsZero()) + assert.Equal(t, int64(-1), verbose.Session.TimeoutInSeconds) } func TestMetadata_VerboseWithRefresh(t *testing.T) { @@ -255,6 +303,66 @@ func TestMetadata_VerboseWithRefresh(t *testing.T) { }) } +func TestMetadata_Verbose_WithTimeout(t *testing.T) { + tokenLifetime := 30 * time.Minute + sessionLifetime := time.Hour + timeout := 15 * time.Minute + + metadata := session.NewMetadata(tokenLifetime, sessionLifetime) + metadata.WithTimeout(timeout) + + maxDelta := time.Second + + verbose := metadata.Verbose() + + assert.True(t, verbose.Session.Active) + assert.False(t, verbose.Session.TimeoutAt.IsZero()) + + expected := time.Now().Add(timeout) + actual := verbose.Session.TimeoutAt + assert.WithinDuration(t, expected, actual, maxDelta) + + expected = time.Now().Add(durationSeconds(verbose.Session.TimeoutInSeconds)) + actual = verbose.Session.TimeoutAt + assert.WithinDuration(t, expected, actual, maxDelta) +} + +func TestMetadata_ExtendTimeout(t *testing.T) { + tokenLifetime := 30 * time.Minute + sessionLifetime := time.Hour + + timeout := 15 * time.Minute + + metadata := session.NewMetadata(tokenLifetime, sessionLifetime) + metadata.WithTimeout(timeout) + + previousTimeoutAt := metadata.Session.TimeoutAt + + metadata.ExtendTimeout(timeout) + assert.True(t, metadata.Session.TimeoutAt.After(previousTimeoutAt)) +} + +func TestMetadata_IsTimedOut(t *testing.T) { + tokenLifetime := 30 * time.Minute + sessionLifetime := time.Hour + + t.Run("timeout is zero", func(t *testing.T) { + metadata := session.NewMetadata(tokenLifetime, sessionLifetime) + assert.False(t, metadata.IsTimedOut()) + }) + + t.Run("timeout is non-zero", func(t *testing.T) { + timeout := 15 * time.Minute + + metadata := session.NewMetadata(tokenLifetime, sessionLifetime) + metadata.WithTimeout(timeout) + assert.False(t, metadata.IsTimedOut()) + + metadata.WithTimeout(-timeout) + assert.True(t, metadata.IsTimedOut()) + }) +} + func durationSeconds(seconds int64) time.Duration { return time.Duration(seconds) * time.Second } diff --git a/pkg/session/handler.go b/pkg/session/handler.go index 3f55204..76e4927 100644 --- a/pkg/session/handler.go +++ b/pkg/session/handler.go @@ -23,10 +23,11 @@ import ( var ( ErrCookieNotFound = errors.New("cookie not found") - ErrNoSessionData = errors.New("no session data") - ErrNoAccessToken = errors.New("no access token in session data") ErrExpiredAccessToken = errors.New("access token is expired") ErrInvalidState = errors.New("invalid state") + ErrNoSessionData = errors.New("no session data") + ErrNoAccessToken = errors.New("no access token in session data") + ErrSessionInactive = errors.New("session is inactive") ) const ( @@ -36,11 +37,11 @@ const ( ) type Handler struct { - client *openidclient.Client - crypter crypto.Crypter - openidCfg openidconfig.Config - refreshEnabled bool - store Store + cfg config.Session + client *openidclient.Client + crypter crypto.Crypter + openidCfg openidconfig.Config + store Store } func NewHandler(cfg *config.Config, openidCfg openidconfig.Config, crypter crypto.Crypter, openidClient *openidclient.Client) (*Handler, error) { @@ -50,11 +51,11 @@ func NewHandler(cfg *config.Config, openidCfg openidconfig.Config, crypter crypt } return &Handler{ - crypter: crypter, - client: openidClient, - openidCfg: openidCfg, - store: store, - refreshEnabled: cfg.Session.Refresh, + crypter: crypter, + client: openidClient, + openidCfg: openidCfg, + store: store, + cfg: cfg.Session, }, nil } @@ -68,6 +69,11 @@ func (h *Handler) Create(r *http.Request, tokens *openid.Tokens, sessionLifetime key := h.Key(externalSessionID) tokenExpiresIn := time.Until(tokens.Expiry) metadata := NewMetadata(tokenExpiresIn, sessionLifetime) + + if h.cfg.Inactivity { + metadata.WithTimeout(h.cfg.InactivityTimeout) + } + encrypted, err := NewData(externalSessionID, tokens, metadata).Encrypt(h.crypter) if err != nil { return "", fmt.Errorf("encrypting session data: %w", err) @@ -202,12 +208,16 @@ func (h *Handler) GetOrRefresh(r *http.Request) (*Data, error) { return nil, err } + if h.isTimedOut(sessionData) { + return nil, ErrSessionInactive + } + if !h.shouldRefresh(sessionData) { return sessionData, nil } refreshed, err := h.Refresh(r, key, sessionData) - if errors.Is(err, ErrInvalidState) { + if errors.Is(err, ErrInvalidState) || errors.Is(err, ErrSessionInactive) { return nil, err } else if err != nil { mw.LogEntryFrom(r).Warnf("session: could not refresh tokens; falling back to existing token: %+v", err) @@ -287,7 +297,7 @@ func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error }(lock, ctx) // Get the latest session state again in case it was changed while acquiring the lock - data, err = h.Get(r) + data, err = h.GetForKey(r, key) if err != nil { return nil, err } @@ -297,6 +307,10 @@ func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error return data, nil } + if h.isTimedOut(data) { + return nil, ErrSessionInactive + } + logger.Debug("session: performing refresh grant...") var resp *openid.TokenResponse refresh := func(ctx context.Context) error { @@ -318,9 +332,23 @@ func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error data.RefreshToken = resp.RefreshToken data.Metadata.Refresh(resp.ExpiresIn) + if h.cfg.Inactivity { + data.Metadata.ExtendTimeout(h.cfg.InactivityTimeout) + } + + err = h.Update(ctx, key, data) + if err != nil { + return nil, err + } + + logger.Info("session: successfully refreshed") + return data, nil +} + +func (h *Handler) Update(ctx context.Context, key string, data *Data) error { encrypted, err := data.Encrypt(h.crypter) if err != nil { - return nil, fmt.Errorf("encrypting session data: %w", err) + return fmt.Errorf("encrypting session data: %w", err) } update := func(ctx context.Context) error { @@ -332,19 +360,22 @@ func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error } if err := retry.Do(ctx, retrypkg.DefaultBackoff, update); err != nil { - return nil, fmt.Errorf("updating in store: %w", err) + return fmt.Errorf("updating in store: %w", err) } - logger.Info("session: successfully refreshed") - return data, nil + return nil } func (h *Handler) canRefresh(data *Data) bool { - return h.refreshEnabled && data.HasRefreshToken() && !data.Metadata.IsRefreshOnCooldown() + return h.cfg.Refresh && data.HasRefreshToken() && !data.Metadata.IsRefreshOnCooldown() } func (h *Handler) shouldRefresh(data *Data) bool { - return h.refreshEnabled && data.HasRefreshToken() && data.Metadata.ShouldRefresh() + return h.cfg.Refresh && data.HasRefreshToken() && data.Metadata.ShouldRefresh() +} + +func (h *Handler) isTimedOut(data *Data) bool { + return h.cfg.Inactivity && data.Metadata.IsTimedOut() } func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url.Values) (string, error) {