mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-14 12:26:34 +00:00
feat: rudimentary support for refresh tokens
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/autologin"
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
@@ -35,7 +36,6 @@ func NewHandler(
|
||||
cfg *config.Config,
|
||||
openidConfig openidconfig.Config,
|
||||
crypter crypto.Crypter,
|
||||
sessionHandler *session.Handler,
|
||||
) (*Handler, error) {
|
||||
openidProvider, err := provider.NewProvider(ctx, openidConfig)
|
||||
if err != nil {
|
||||
@@ -47,13 +47,25 @@ func NewHandler(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Timeout: time.Second * 10,
|
||||
}
|
||||
|
||||
openidClient := client.NewClient(openidConfig)
|
||||
openidClient.SetHttpClient(httpClient)
|
||||
|
||||
sessionHandler, err := session.NewHandler(cfg, openidConfig, crypter, openidClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
AutoLogin: autoLogin,
|
||||
Client: client.NewClient(openidConfig),
|
||||
Client: openidClient,
|
||||
Config: cfg,
|
||||
CookieOptions: cookie.DefaultOptions(),
|
||||
Crypter: crypter,
|
||||
Loginstatus: loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient),
|
||||
Loginstatus: loginstatus.NewClient(cfg.Loginstatus, httpClient),
|
||||
OpenIDConfig: openidConfig,
|
||||
Provider: openidProvider,
|
||||
ReverseProxy: newReverseProxy(cfg.UpstreamHost),
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/sethvargo/go-retry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -56,15 +55,16 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
expiresIn := h.getSessionLifetime(tokens.Expiry)
|
||||
key, err := h.Sessions.Create(r, tokens, expiresIn)
|
||||
sessionLifetime := h.Config.Session.MaxLifetime
|
||||
|
||||
key, err := h.Sessions.Create(r, tokens, sessionLifetime)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
opts := h.CookieOptsPathAware(r).
|
||||
WithExpiresIn(expiresIn)
|
||||
WithExpiresIn(sessionLifetime)
|
||||
err = cookie.EncryptAndSet(w, cookie.Session, key, opts, h.Crypter)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err))
|
||||
@@ -92,11 +92,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC
|
||||
|
||||
retryable := func(ctx context.Context) error {
|
||||
tokens, err = loginCallback.RedeemTokens(ctx)
|
||||
if err != nil {
|
||||
return retry.RetryableError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return retry.RetryableError(err)
|
||||
}
|
||||
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
@@ -106,18 +102,6 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration {
|
||||
defaultSessionLifetime := h.Config.SessionMaxLifetime
|
||||
|
||||
tokenDuration := tokenExpiry.Sub(time.Now())
|
||||
|
||||
if tokenDuration <= defaultSessionLifetime {
|
||||
return tokenDuration
|
||||
}
|
||||
|
||||
return defaultSessionLifetime
|
||||
}
|
||||
|
||||
func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) {
|
||||
var tokenResponse *loginstatus.TokenResponse
|
||||
|
||||
@@ -125,11 +109,7 @@ func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*
|
||||
var err error
|
||||
|
||||
tokenResponse, err = h.Loginstatus.ExchangeToken(ctx, tokens.AccessToken)
|
||||
if err != nil {
|
||||
return retry.RetryableError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return retry.RetryableError(err)
|
||||
}
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -49,8 +49,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) {
|
||||
sessionData, err := h.Sessions.Get(r)
|
||||
if err == nil && sessionData != nil && len(sessionData.AccessToken) > 0 {
|
||||
sessionData, err := h.Sessions.GetOrRefresh(r)
|
||||
if err == nil && sessionData != nil && sessionData.HasAccessToken() {
|
||||
return sessionData.AccessToken, true
|
||||
}
|
||||
|
||||
|
||||
36
pkg/handler/handler_session_info.go
Normal file
36
pkg/handler/handler_session_info.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
// SessionInfo returns metadata for the current user's session.
|
||||
func (h *Handler) SessionInfo(w http.ResponseWriter, r *http.Request) {
|
||||
logger := mw.LogEntryFrom(r)
|
||||
|
||||
data, err := h.Sessions.Get(r)
|
||||
if err != nil {
|
||||
if errors.Is(err, session.CookieNotFoundError) || errors.Is(err, session.KeyNotFoundError) {
|
||||
logger.Infof("session/info: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Warnf("session/info: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data.Metadata.Verbose())
|
||||
if err != nil {
|
||||
logger.Warnf("session/info: marshalling metadata: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
50
pkg/handler/handler_session_refresh.go
Normal file
50
pkg/handler/handler_session_refresh.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
// SessionRefresh refreshes current user's session and returns the associated updated metadata.
|
||||
func (h *Handler) SessionRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
logger := mw.LogEntryFrom(r)
|
||||
|
||||
key, err := h.Sessions.GetKey(r)
|
||||
if err != nil {
|
||||
logger.Infof("session/refresh: getting key: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := h.Sessions.Get(r)
|
||||
if err != nil {
|
||||
if errors.Is(err, session.KeyNotFoundError) {
|
||||
logger.Infof("session/refresh: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Warnf("session/refresh: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data, err = h.Sessions.Refresh(r, key, data)
|
||||
if err != nil {
|
||||
logger.Warnf("session/refresh: refreshing: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data.Metadata.Verbose())
|
||||
if err != nil {
|
||||
logger.Warnf("session/refresh: marshalling metadata: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package handler_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -10,12 +11,14 @@ import (
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
func TestHandler_Login(t *testing.T) {
|
||||
@@ -122,11 +125,9 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
return data.ExternalSessionID
|
||||
}
|
||||
|
||||
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL)
|
||||
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
|
||||
assert.NoError(t, err)
|
||||
|
||||
frontchannelLogoutURL.Path = "/oauth2/logout/frontchannel"
|
||||
|
||||
req := idp.GetRequest(frontchannelLogoutURL.String())
|
||||
|
||||
values := url.Values{}
|
||||
@@ -154,6 +155,161 @@ func TestHandler_SessionStateRequired(t *testing.T) {
|
||||
assert.NotEmpty(t, sessionState)
|
||||
}
|
||||
|
||||
func TestHandler_SessionInfo(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Minute
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerbose
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allowedSkew := 5 * time.Second
|
||||
assert.WithinDuration(t, time.Now(), data.SessionCreatedAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.SessionEndsAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.TokensExpireAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now(), data.TokensRefreshedAt, allowedSkew)
|
||||
|
||||
sessionEndDuration := time.Duration(data.SessionEndsInSeconds) * time.Second
|
||||
// 1 second < time until session ends <= configured max session lifetime
|
||||
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
||||
assert.Greater(t, sessionEndDuration, time.Second)
|
||||
|
||||
tokenExpiryDuration := time.Duration(data.TokensExpireInSeconds) * time.Second
|
||||
// 1 second < time until token expires <= max duration for tokens from IDP
|
||||
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
||||
assert.Greater(t, tokenExpiryDuration, time.Second)
|
||||
|
||||
// 1 second < next token refresh <= seconds until token expires
|
||||
assert.LessOrEqual(t, data.TokensNextRefreshInSeconds, data.TokensExpireInSeconds)
|
||||
assert.Greater(t, data.TokensNextRefreshInSeconds, int64(1))
|
||||
|
||||
assert.True(t, data.TokensRefreshCooldown)
|
||||
// 1 second < refresh cooldown <= minimum refresh interval
|
||||
assert.LessOrEqual(t, data.TokensRefreshCooldownSeconds, session.RefreshMinInterval)
|
||||
assert.Greater(t, data.TokensRefreshCooldownSeconds, int64(1))
|
||||
}
|
||||
|
||||
func TestHandler_SessionInfo_Disabled(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = false
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestHandler_SessionRefresh(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
// get initial session info
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerbose
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// wait until refresh cooldown has reached zero before refresh
|
||||
func() {
|
||||
timeout := time.After(5 * time.Second)
|
||||
ticker := time.Tick(500 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
assert.Fail(t, "refresh cooldown timer exceeded timeout")
|
||||
case <-ticker:
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var temp session.MetadataVerbose
|
||||
err = json.Unmarshal([]byte(resp.Body), &temp)
|
||||
assert.NoError(t, err)
|
||||
|
||||
if !temp.TokensRefreshCooldown {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
resp = sessionRefresh(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var refreshedData session.MetadataVerbose
|
||||
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// session create and end times should be unchanged
|
||||
assert.WithinDuration(t, data.SessionCreatedAt, refreshedData.SessionCreatedAt, 0)
|
||||
assert.WithinDuration(t, data.SessionEndsAt, refreshedData.SessionEndsAt, 0)
|
||||
|
||||
// token expiration and refresh times should be later than before
|
||||
assert.True(t, refreshedData.TokensExpireAt.After(data.TokensExpireAt))
|
||||
assert.True(t, refreshedData.TokensRefreshedAt.After(data.TokensRefreshedAt))
|
||||
|
||||
allowedSkew := 5 * time.Second
|
||||
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.TokensExpireAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now(), refreshedData.TokensRefreshedAt, allowedSkew)
|
||||
|
||||
sessionEndDuration := time.Duration(refreshedData.SessionEndsInSeconds) * time.Second
|
||||
// 1 second < time until session ends <= configured max session lifetime
|
||||
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
||||
assert.Greater(t, sessionEndDuration, time.Second)
|
||||
|
||||
tokenExpiryDuration := time.Duration(refreshedData.TokensExpireInSeconds) * time.Second
|
||||
// 1 second < time until token expires <= max duration for tokens from IDP
|
||||
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
||||
assert.Greater(t, tokenExpiryDuration, time.Second)
|
||||
|
||||
// 1 second < next token refresh <= seconds until token expires
|
||||
assert.LessOrEqual(t, refreshedData.TokensNextRefreshInSeconds, refreshedData.TokensExpireInSeconds)
|
||||
assert.Greater(t, refreshedData.TokensNextRefreshInSeconds, int64(1))
|
||||
|
||||
assert.True(t, refreshedData.TokensRefreshCooldown)
|
||||
// 1 second < refresh cooldown <= minimum refresh interval
|
||||
assert.LessOrEqual(t, refreshedData.TokensRefreshCooldownSeconds, session.RefreshMinInterval)
|
||||
assert.Greater(t, refreshedData.TokensRefreshCooldownSeconds, int64(1))
|
||||
}
|
||||
|
||||
func TestHandler_SessionRefresh_Disabled(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = false
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionRefresh(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestHandler_Default(t *testing.T) {
|
||||
up := newUpstream(t)
|
||||
defer up.Server.Close()
|
||||
@@ -413,6 +569,20 @@ func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
|
||||
assert.Nil(t, sessionCookie)
|
||||
}
|
||||
|
||||
func sessionInfo(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
|
||||
sessionInfoURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session")
|
||||
assert.NoError(t, err)
|
||||
|
||||
return get(t, rpClient, sessionInfoURL.String())
|
||||
}
|
||||
|
||||
func sessionRefresh(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
|
||||
sessionRefreshURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session/refresh")
|
||||
assert.NoError(t, err)
|
||||
|
||||
return get(t, rpClient, sessionRefreshURL.String())
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Body string
|
||||
Location *url.URL
|
||||
|
||||
Reference in New Issue
Block a user