feat: rudimentary support for refresh tokens

This commit is contained in:
Trong Huu Nguyen
2022-08-25 11:30:04 +02:00
parent dc0741f79f
commit d5bbca9897
30 changed files with 1048 additions and 335 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"net/http"
"net/http/httputil"
"time"
"github.com/nais/wonderwall/pkg/autologin"
"github.com/nais/wonderwall/pkg/config"
@@ -35,7 +36,6 @@ func NewHandler(
cfg *config.Config,
openidConfig openidconfig.Config,
crypter crypto.Crypter,
sessionHandler *session.Handler,
) (*Handler, error) {
openidProvider, err := provider.NewProvider(ctx, openidConfig)
if err != nil {
@@ -47,13 +47,25 @@ func NewHandler(
return nil, err
}
httpClient := &http.Client{
Timeout: time.Second * 10,
}
openidClient := client.NewClient(openidConfig)
openidClient.SetHttpClient(httpClient)
sessionHandler, err := session.NewHandler(cfg, openidConfig, crypter, openidClient)
if err != nil {
return nil, err
}
return &Handler{
AutoLogin: autoLogin,
Client: client.NewClient(openidConfig),
Client: openidClient,
Config: cfg,
CookieOptions: cookie.DefaultOptions(),
Crypter: crypter,
Loginstatus: loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient),
Loginstatus: loginstatus.NewClient(cfg.Loginstatus, httpClient),
OpenIDConfig: openidConfig,
Provider: openidProvider,
ReverseProxy: newReverseProxy(cfg.UpstreamHost),

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/http"
"time"
"github.com/sethvargo/go-retry"
log "github.com/sirupsen/logrus"
@@ -56,15 +55,16 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
expiresIn := h.getSessionLifetime(tokens.Expiry)
key, err := h.Sessions.Create(r, tokens, expiresIn)
sessionLifetime := h.Config.Session.MaxLifetime
key, err := h.Sessions.Create(r, tokens, sessionLifetime)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
return
}
opts := h.CookieOptsPathAware(r).
WithExpiresIn(expiresIn)
WithExpiresIn(sessionLifetime)
err = cookie.EncryptAndSet(w, cookie.Session, key, opts, h.Crypter)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err))
@@ -92,11 +92,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC
retryable := func(ctx context.Context) error {
tokens, err = loginCallback.RedeemTokens(ctx)
if err != nil {
return retry.RetryableError(err)
}
return nil
return retry.RetryableError(err)
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
@@ -106,18 +102,6 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC
return tokens, nil
}
func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration {
defaultSessionLifetime := h.Config.SessionMaxLifetime
tokenDuration := tokenExpiry.Sub(time.Now())
if tokenDuration <= defaultSessionLifetime {
return tokenDuration
}
return defaultSessionLifetime
}
func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) {
var tokenResponse *loginstatus.TokenResponse
@@ -125,11 +109,7 @@ func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*
var err error
tokenResponse, err = h.Loginstatus.ExchangeToken(ctx, tokens.AccessToken)
if err != nil {
return retry.RetryableError(err)
}
return nil
return retry.RetryableError(err)
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return nil, err

View File

@@ -49,8 +49,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) {
sessionData, err := h.Sessions.Get(r)
if err == nil && sessionData != nil && len(sessionData.AccessToken) > 0 {
sessionData, err := h.Sessions.GetOrRefresh(r)
if err == nil && sessionData != nil && sessionData.HasAccessToken() {
return sessionData.AccessToken, true
}

View File

@@ -0,0 +1,36 @@
package handler
import (
"encoding/json"
"errors"
"net/http"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
)
// SessionInfo returns metadata for the current user's session.
func (h *Handler) SessionInfo(w http.ResponseWriter, r *http.Request) {
logger := mw.LogEntryFrom(r)
data, err := h.Sessions.Get(r)
if err != nil {
if errors.Is(err, session.CookieNotFoundError) || errors.Is(err, session.KeyNotFoundError) {
logger.Infof("session/info: getting session: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
logger.Warnf("session/info: getting session: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data.Metadata.Verbose())
if err != nil {
logger.Warnf("session/info: marshalling metadata: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

View File

@@ -0,0 +1,50 @@
package handler
import (
"encoding/json"
"errors"
"net/http"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
)
// SessionRefresh refreshes current user's session and returns the associated updated metadata.
func (h *Handler) SessionRefresh(w http.ResponseWriter, r *http.Request) {
logger := mw.LogEntryFrom(r)
key, err := h.Sessions.GetKey(r)
if err != nil {
logger.Infof("session/refresh: getting key: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
data, err := h.Sessions.Get(r)
if err != nil {
if errors.Is(err, session.KeyNotFoundError) {
logger.Infof("session/refresh: getting session: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
logger.Warnf("session/refresh: getting session: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
data, err = h.Sessions.Refresh(r, key, data)
if err != nil {
logger.Warnf("session/refresh: refreshing: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data.Metadata.Verbose())
if err != nil {
logger.Warnf("session/refresh: marshalling metadata: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

View File

@@ -2,6 +2,7 @@ package handler_test
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
@@ -10,12 +11,14 @@ import (
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/cookie"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/session"
)
func TestHandler_Login(t *testing.T) {
@@ -122,11 +125,9 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
return data.ExternalSessionID
}
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL)
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
assert.NoError(t, err)
frontchannelLogoutURL.Path = "/oauth2/logout/frontchannel"
req := idp.GetRequest(frontchannelLogoutURL.String())
values := url.Values{}
@@ -154,6 +155,161 @@ func TestHandler_SessionStateRequired(t *testing.T) {
assert.NotEmpty(t, sessionState)
}
func TestHandler_SessionInfo(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Minute
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerbose
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now(), data.SessionCreatedAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.SessionEndsAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.TokensExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), data.TokensRefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(data.SessionEndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(data.TokensExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
// 1 second < next token refresh <= seconds until token expires
assert.LessOrEqual(t, data.TokensNextRefreshInSeconds, data.TokensExpireInSeconds)
assert.Greater(t, data.TokensNextRefreshInSeconds, int64(1))
assert.True(t, data.TokensRefreshCooldown)
// 1 second < refresh cooldown <= minimum refresh interval
assert.LessOrEqual(t, data.TokensRefreshCooldownSeconds, session.RefreshMinInterval)
assert.Greater(t, data.TokensRefreshCooldownSeconds, int64(1))
}
func TestHandler_SessionInfo_Disabled(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = false
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
func TestHandler_SessionRefresh(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
// get initial session info
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerbose
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
// wait until refresh cooldown has reached zero before refresh
func() {
timeout := time.After(5 * time.Second)
ticker := time.Tick(500 * time.Millisecond)
for {
select {
case <-timeout:
assert.Fail(t, "refresh cooldown timer exceeded timeout")
case <-ticker:
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var temp session.MetadataVerbose
err = json.Unmarshal([]byte(resp.Body), &temp)
assert.NoError(t, err)
if !temp.TokensRefreshCooldown {
return
}
}
}
}()
resp = sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var refreshedData session.MetadataVerbose
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
assert.NoError(t, err)
// session create and end times should be unchanged
assert.WithinDuration(t, data.SessionCreatedAt, refreshedData.SessionCreatedAt, 0)
assert.WithinDuration(t, data.SessionEndsAt, refreshedData.SessionEndsAt, 0)
// token expiration and refresh times should be later than before
assert.True(t, refreshedData.TokensExpireAt.After(data.TokensExpireAt))
assert.True(t, refreshedData.TokensRefreshedAt.After(data.TokensRefreshedAt))
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.TokensExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), refreshedData.TokensRefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(refreshedData.SessionEndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(refreshedData.TokensExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
// 1 second < next token refresh <= seconds until token expires
assert.LessOrEqual(t, refreshedData.TokensNextRefreshInSeconds, refreshedData.TokensExpireInSeconds)
assert.Greater(t, refreshedData.TokensNextRefreshInSeconds, int64(1))
assert.True(t, refreshedData.TokensRefreshCooldown)
// 1 second < refresh cooldown <= minimum refresh interval
assert.LessOrEqual(t, refreshedData.TokensRefreshCooldownSeconds, session.RefreshMinInterval)
assert.Greater(t, refreshedData.TokensRefreshCooldownSeconds, int64(1))
}
func TestHandler_SessionRefresh_Disabled(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = false
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
func TestHandler_Default(t *testing.T) {
up := newUpstream(t)
defer up.Server.Close()
@@ -413,6 +569,20 @@ func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
assert.Nil(t, sessionCookie)
}
func sessionInfo(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
sessionInfoURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session")
assert.NoError(t, err)
return get(t, rpClient, sessionInfoURL.String())
}
func sessionRefresh(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
sessionRefreshURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session/refresh")
assert.NoError(t, err)
return get(t, rpClient, sessionRefreshURL.String())
}
type response struct {
Body string
Location *url.URL