From d5bbca9897274cd38405a44bf592ec1f5df3f18c Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 25 Aug 2022 11:30:04 +0200 Subject: [PATCH] feat: rudimentary support for refresh tokens --- README.md | 7 +- cmd/wonderwall/main.go | 8 +- pkg/config/config.go | 24 ++- pkg/handler/handler.go | 18 +- pkg/handler/handler_callback.go | 32 +--- pkg/handler/handler_default.go | 4 +- pkg/handler/handler_session_info.go | 36 ++++ pkg/handler/handler_session_refresh.go | 50 ++++++ pkg/handler/handler_test.go | 176 ++++++++++++++++++- pkg/mock/config.go | 4 +- pkg/mock/openid.go | 227 ++++++++++++++++++------- pkg/openid/client/client.go | 63 ++++++- pkg/openid/client/login_callback.go | 6 +- pkg/openid/config/client.go | 1 + pkg/openid/params.go | 3 + pkg/openid/params_values.go | 6 + pkg/openid/response.go | 17 ++ pkg/openid/scopes/scopes.go | 5 + pkg/router/paths/paths.go | 2 + pkg/router/router.go | 5 + pkg/session/data.go | 202 ++++++++++++++++++++++ pkg/session/data_test.go | 41 +++++ pkg/session/handler.go | 130 +++++++++++--- pkg/session/models.go | 105 ------------ pkg/session/store.go | 1 + pkg/session/store_memory.go | 8 + pkg/session/store_memory_test.go | 50 ++---- pkg/session/store_redis.go | 11 ++ pkg/session/store_redis_test.go | 50 ++---- pkg/session/store_test.go | 91 ++++++++++ 30 files changed, 1048 insertions(+), 335 deletions(-) create mode 100644 pkg/handler/handler_session_info.go create mode 100644 pkg/handler/handler_session_refresh.go create mode 100644 pkg/openid/params_values.go create mode 100644 pkg/openid/response.go create mode 100644 pkg/session/data.go create mode 100644 pkg/session/data_test.go delete mode 100644 pkg/session/models.go create mode 100644 pkg/session/store_test.go diff --git a/README.md b/README.md index 5f85be2..220d1c8 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ The following flags are available: --bind-address string Listen address for public connections. (default "127.0.0.1:3000") --encryption-key string Base64 encoded 256-bit cookie encryption key; must be identical in instances that share session store. --error-redirect-uri string URI to redirect user to on errors for custom error handling. ---ingress string Comma separated list of ingresses used to access the main application. +--ingress strings Comma separated list of ingresses used to access the main application. --log-format string Log format, either 'json' or 'text'. (default "json") --log-level string Logging verbosity level. (default "info") --loginstatus.cookie-domain string The domain that the cookie should be set for. @@ -102,10 +102,13 @@ The following flags are available: --redis.password string Password for Redis. --redis.tls Whether or not to use TLS for connecting to Redis. (default true) --redis.username string Username for Redis. ---session-max-lifetime duration Max lifetime for user sessions. (default 1h0m0s) +--session.max-lifetime duration Max lifetime for user sessions. (default 1h0m0s) +--session.refresh Automatically refresh the tokens for user sessions if they are expired, as long as the session exists (indicated by the session max lifetime). --upstream-host string Address of upstream host. (default "127.0.0.1:8080") ``` +Boolean flags/options are by default set to `false` unless noted otherwise. + At minimum, the following configuration must be provided: - `openid.client-id` diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 65156e2..700c3dc 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -13,7 +13,6 @@ import ( openidconfig "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/server" - "github.com/nais/wonderwall/pkg/session" ) func run() error { @@ -36,12 +35,7 @@ func run() error { defer cancel() crypt := crypto.NewCrypter(key) - sessionHandler, err := session.NewHandler(cfg, openidConfig, crypt) - if err != nil { - return err - } - - h, err := handler.NewHandler(ctx, cfg, openidConfig, crypt, sessionHandler) + h, err := handler.NewHandler(ctx, cfg, openidConfig, crypt) if err != nil { return fmt.Errorf("initializing routing handler: %w", err) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 9c36676..cef7775 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -17,13 +17,13 @@ type Config struct { LogLevel string `json:"log-level"` MetricsBindAddress string `json:"metrics-bind-address"` - AutoLogin bool `json:"auto-login"` - AutoLoginIgnorePaths []string `json:"auto-login-ignore-paths"` - EncryptionKey string `json:"encryption-key"` - ErrorRedirectURI string `json:"error-redirect-uri"` - Ingresses []string `json:"ingress"` - SessionMaxLifetime time.Duration `json:"session-max-lifetime"` - UpstreamHost string `json:"upstream-host"` + AutoLogin bool `json:"auto-login"` + AutoLoginIgnorePaths []string `json:"auto-login-ignore-paths"` + EncryptionKey string `json:"encryption-key"` + ErrorRedirectURI string `json:"error-redirect-uri"` + Ingresses []string `json:"ingress"` + Session Session `json:"session"` + UpstreamHost string `json:"upstream-host"` OpenID OpenID `json:"openid"` Redis Redis `json:"redis"` @@ -39,6 +39,11 @@ type Loginstatus struct { TokenURL string `json:"token-url"` } +type Session struct { + MaxLifetime time.Duration `json:"max-lifetime"` + Refresh bool `json:"refresh"` +} + const ( BindAddress = "bind-address" LogFormat = "log-format" @@ -50,9 +55,11 @@ const ( EncryptionKey = "encryption-key" ErrorRedirectURI = "error-redirect-uri" Ingress = "ingress" - SessionMaxLifetime = "session-max-lifetime" UpstreamHost = "upstream-host" + SessionMaxLifetime = "session.max-lifetime" + SessionRefresh = "session.refresh" + LoginstatusEnabled = "loginstatus.enabled" LoginstatusCookieDomain = "loginstatus.cookie-domain" LoginstatusCookieName = "loginstatus.cookie-name" @@ -74,6 +81,7 @@ func Initialize() (*Config, error) { flag.String(ErrorRedirectURI, "", "URI to redirect user to on errors for custom error handling.") flag.StringSlice(Ingress, []string{}, "Comma separated list of ingresses used to access the main application.") flag.Duration(SessionMaxLifetime, time.Hour, "Max lifetime for user sessions.") + flag.Bool(SessionRefresh, false, "Automatically refresh the tokens for user sessions if they are expired, as long as the session exists (indicated by the session max lifetime).") flag.String(UpstreamHost, "127.0.0.1:8080", "Address of upstream host.") flag.Bool(LoginstatusEnabled, false, "Feature toggle for Loginstatus, a separate service that should provide an opaque token to indicate that a user has been authenticated previously, e.g. by another application in another subdomain.") diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 821bc49..c19e0d3 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httputil" + "time" "github.com/nais/wonderwall/pkg/autologin" "github.com/nais/wonderwall/pkg/config" @@ -35,7 +36,6 @@ func NewHandler( cfg *config.Config, openidConfig openidconfig.Config, crypter crypto.Crypter, - sessionHandler *session.Handler, ) (*Handler, error) { openidProvider, err := provider.NewProvider(ctx, openidConfig) if err != nil { @@ -47,13 +47,25 @@ func NewHandler( return nil, err } + httpClient := &http.Client{ + Timeout: time.Second * 10, + } + + openidClient := client.NewClient(openidConfig) + openidClient.SetHttpClient(httpClient) + + sessionHandler, err := session.NewHandler(cfg, openidConfig, crypter, openidClient) + if err != nil { + return nil, err + } + return &Handler{ AutoLogin: autoLogin, - Client: client.NewClient(openidConfig), + Client: openidClient, Config: cfg, CookieOptions: cookie.DefaultOptions(), Crypter: crypter, - Loginstatus: loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient), + Loginstatus: loginstatus.NewClient(cfg.Loginstatus, httpClient), OpenIDConfig: openidConfig, Provider: openidProvider, ReverseProxy: newReverseProxy(cfg.UpstreamHost), diff --git a/pkg/handler/handler_callback.go b/pkg/handler/handler_callback.go index 7bcb9ea..43fc744 100644 --- a/pkg/handler/handler_callback.go +++ b/pkg/handler/handler_callback.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "time" "github.com/sethvargo/go-retry" log "github.com/sirupsen/logrus" @@ -56,15 +55,16 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - expiresIn := h.getSessionLifetime(tokens.Expiry) - key, err := h.Sessions.Create(r, tokens, expiresIn) + sessionLifetime := h.Config.Session.MaxLifetime + + key, err := h.Sessions.Create(r, tokens, sessionLifetime) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err)) return } opts := h.CookieOptsPathAware(r). - WithExpiresIn(expiresIn) + WithExpiresIn(sessionLifetime) err = cookie.EncryptAndSet(w, cookie.Session, key, opts, h.Crypter) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err)) @@ -92,11 +92,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC retryable := func(ctx context.Context) error { tokens, err = loginCallback.RedeemTokens(ctx) - if err != nil { - return retry.RetryableError(err) - } - - return nil + return retry.RetryableError(err) } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { @@ -106,18 +102,6 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC return tokens, nil } -func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration { - defaultSessionLifetime := h.Config.SessionMaxLifetime - - tokenDuration := tokenExpiry.Sub(time.Now()) - - if tokenDuration <= defaultSessionLifetime { - return tokenDuration - } - - return defaultSessionLifetime -} - func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) { var tokenResponse *loginstatus.TokenResponse @@ -125,11 +109,7 @@ func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (* var err error tokenResponse, err = h.Loginstatus.ExchangeToken(ctx, tokens.AccessToken) - if err != nil { - return retry.RetryableError(err) - } - - return nil + return retry.RetryableError(err) } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { return nil, err diff --git a/pkg/handler/handler_default.go b/pkg/handler/handler_default.go index 615c0d3..bd899dc 100644 --- a/pkg/handler/handler_default.go +++ b/pkg/handler/handler_default.go @@ -49,8 +49,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { } func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) { - sessionData, err := h.Sessions.Get(r) - if err == nil && sessionData != nil && len(sessionData.AccessToken) > 0 { + sessionData, err := h.Sessions.GetOrRefresh(r) + if err == nil && sessionData != nil && sessionData.HasAccessToken() { return sessionData.AccessToken, true } diff --git a/pkg/handler/handler_session_info.go b/pkg/handler/handler_session_info.go new file mode 100644 index 0000000..757da66 --- /dev/null +++ b/pkg/handler/handler_session_info.go @@ -0,0 +1,36 @@ +package handler + +import ( + "encoding/json" + "errors" + "net/http" + + mw "github.com/nais/wonderwall/pkg/middleware" + "github.com/nais/wonderwall/pkg/session" +) + +// SessionInfo returns metadata for the current user's session. +func (h *Handler) SessionInfo(w http.ResponseWriter, r *http.Request) { + logger := mw.LogEntryFrom(r) + + data, err := h.Sessions.Get(r) + if err != nil { + if errors.Is(err, session.CookieNotFoundError) || errors.Is(err, session.KeyNotFoundError) { + logger.Infof("session/info: getting session: %+v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + logger.Warnf("session/info: getting session: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(data.Metadata.Verbose()) + if err != nil { + logger.Warnf("session/info: marshalling metadata: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } +} diff --git a/pkg/handler/handler_session_refresh.go b/pkg/handler/handler_session_refresh.go new file mode 100644 index 0000000..0fdb879 --- /dev/null +++ b/pkg/handler/handler_session_refresh.go @@ -0,0 +1,50 @@ +package handler + +import ( + "encoding/json" + "errors" + "net/http" + + mw "github.com/nais/wonderwall/pkg/middleware" + "github.com/nais/wonderwall/pkg/session" +) + +// SessionRefresh refreshes current user's session and returns the associated updated metadata. +func (h *Handler) SessionRefresh(w http.ResponseWriter, r *http.Request) { + logger := mw.LogEntryFrom(r) + + key, err := h.Sessions.GetKey(r) + if err != nil { + logger.Infof("session/refresh: getting key: %+v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + data, err := h.Sessions.Get(r) + if err != nil { + if errors.Is(err, session.KeyNotFoundError) { + logger.Infof("session/refresh: getting session: %+v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + logger.Warnf("session/refresh: getting session: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + data, err = h.Sessions.Refresh(r, key, data) + if err != nil { + logger.Warnf("session/refresh: refreshing: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(data.Metadata.Verbose()) + if err != nil { + logger.Warnf("session/refresh: marshalling metadata: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } +} diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index 9637a65..7949c8e 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -2,6 +2,7 @@ package handler_test import ( "encoding/base64" + "encoding/json" "errors" "fmt" "io" @@ -10,12 +11,14 @@ import ( "net/url" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/cookie" urlpkg "github.com/nais/wonderwall/pkg/handler/url" "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/session" ) func TestHandler_Login(t *testing.T) { @@ -122,11 +125,9 @@ func TestHandler_FrontChannelLogout(t *testing.T) { return data.ExternalSessionID } - frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL) + frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel") assert.NoError(t, err) - frontchannelLogoutURL.Path = "/oauth2/logout/frontchannel" - req := idp.GetRequest(frontchannelLogoutURL.String()) values := url.Values{} @@ -154,6 +155,161 @@ func TestHandler_SessionStateRequired(t *testing.T) { assert.NotEmpty(t, sessionState) } +func TestHandler_SessionInfo(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Minute + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerbose + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + allowedSkew := 5 * time.Second + assert.WithinDuration(t, time.Now(), data.SessionCreatedAt, allowedSkew) + assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.SessionEndsAt, allowedSkew) + assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.TokensExpireAt, allowedSkew) + assert.WithinDuration(t, time.Now(), data.TokensRefreshedAt, allowedSkew) + + sessionEndDuration := time.Duration(data.SessionEndsInSeconds) * time.Second + // 1 second < time until session ends <= configured max session lifetime + assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime) + assert.Greater(t, sessionEndDuration, time.Second) + + tokenExpiryDuration := time.Duration(data.TokensExpireInSeconds) * time.Second + // 1 second < time until token expires <= max duration for tokens from IDP + assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) + assert.Greater(t, tokenExpiryDuration, time.Second) + + // 1 second < next token refresh <= seconds until token expires + assert.LessOrEqual(t, data.TokensNextRefreshInSeconds, data.TokensExpireInSeconds) + assert.Greater(t, data.TokensNextRefreshInSeconds, int64(1)) + + assert.True(t, data.TokensRefreshCooldown) + // 1 second < refresh cooldown <= minimum refresh interval + assert.LessOrEqual(t, data.TokensRefreshCooldownSeconds, session.RefreshMinInterval) + assert.Greater(t, data.TokensRefreshCooldownSeconds, int64(1)) +} + +func TestHandler_SessionInfo_Disabled(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = false + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Second + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + +func TestHandler_SessionRefresh(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Second + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + // get initial session info + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerbose + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + // wait until refresh cooldown has reached zero before refresh + func() { + timeout := time.After(5 * time.Second) + ticker := time.Tick(500 * time.Millisecond) + for { + select { + case <-timeout: + assert.Fail(t, "refresh cooldown timer exceeded timeout") + case <-ticker: + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var temp session.MetadataVerbose + err = json.Unmarshal([]byte(resp.Body), &temp) + assert.NoError(t, err) + + if !temp.TokensRefreshCooldown { + return + } + } + } + }() + + resp = sessionRefresh(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var refreshedData session.MetadataVerbose + err = json.Unmarshal([]byte(resp.Body), &refreshedData) + assert.NoError(t, err) + + // session create and end times should be unchanged + assert.WithinDuration(t, data.SessionCreatedAt, refreshedData.SessionCreatedAt, 0) + assert.WithinDuration(t, data.SessionEndsAt, refreshedData.SessionEndsAt, 0) + + // token expiration and refresh times should be later than before + assert.True(t, refreshedData.TokensExpireAt.After(data.TokensExpireAt)) + assert.True(t, refreshedData.TokensRefreshedAt.After(data.TokensRefreshedAt)) + + allowedSkew := 5 * time.Second + assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.TokensExpireAt, allowedSkew) + assert.WithinDuration(t, time.Now(), refreshedData.TokensRefreshedAt, allowedSkew) + + sessionEndDuration := time.Duration(refreshedData.SessionEndsInSeconds) * time.Second + // 1 second < time until session ends <= configured max session lifetime + assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime) + assert.Greater(t, sessionEndDuration, time.Second) + + tokenExpiryDuration := time.Duration(refreshedData.TokensExpireInSeconds) * time.Second + // 1 second < time until token expires <= max duration for tokens from IDP + assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) + assert.Greater(t, tokenExpiryDuration, time.Second) + + // 1 second < next token refresh <= seconds until token expires + assert.LessOrEqual(t, refreshedData.TokensNextRefreshInSeconds, refreshedData.TokensExpireInSeconds) + assert.Greater(t, refreshedData.TokensNextRefreshInSeconds, int64(1)) + + assert.True(t, refreshedData.TokensRefreshCooldown) + // 1 second < refresh cooldown <= minimum refresh interval + assert.LessOrEqual(t, refreshedData.TokensRefreshCooldownSeconds, session.RefreshMinInterval) + assert.Greater(t, refreshedData.TokensRefreshCooldownSeconds, int64(1)) +} + +func TestHandler_SessionRefresh_Disabled(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = false + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Second + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := sessionRefresh(t, idp, rpClient) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + func TestHandler_Default(t *testing.T) { up := newUpstream(t) defer up.Server.Close() @@ -413,6 +569,20 @@ func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) { assert.Nil(t, sessionCookie) } +func sessionInfo(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response { + sessionInfoURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session") + assert.NoError(t, err) + + return get(t, rpClient, sessionInfoURL.String()) +} + +func sessionRefresh(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response { + sessionRefreshURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session/refresh") + assert.NoError(t, err) + + return get(t, rpClient, sessionRefreshURL.String()) +} + type response struct { Body string Location *url.URL diff --git a/pkg/mock/config.go b/pkg/mock/config.go index c062112..ffad3b3 100644 --- a/pkg/mock/config.go +++ b/pkg/mock/config.go @@ -23,7 +23,9 @@ func Config() *config.Config { Scopes: []string{"some-scope"}, UILocales: "nb", }, - SessionMaxLifetime: time.Hour, + Session: config.Session{ + MaxLifetime: time.Hour, + }, } } diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index 5a8760a..0319730 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -20,11 +20,11 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/crypto" handlerpkg "github.com/nais/wonderwall/pkg/handler" + "github.com/nais/wonderwall/pkg/openid" openidclient "github.com/nais/wonderwall/pkg/openid/client" openidconfig "github.com/nais/wonderwall/pkg/openid/config" scopespkg "github.com/nais/wonderwall/pkg/openid/scopes" "github.com/nais/wonderwall/pkg/router" - "github.com/nais/wonderwall/pkg/session" ) type IdentityProvider struct { @@ -82,13 +82,9 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider { openidConfig.TestProvider.SetTokenEndpoint(server.URL + "/token") crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey)) - sessionHandler, err := session.NewHandler(cfg, openidConfig, crypter) - if err != nil { - panic(err) - } ctx, cancel := context.WithCancel(context.Background()) - rpHandler, err := handlerpkg.NewHandler(ctx, cfg, openidConfig, crypter, sessionHandler) + rpHandler, err := handlerpkg.NewHandler(ctx, cfg, openidConfig, crypter) if err != nil { panic(err) } @@ -98,7 +94,6 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider { // reconfigure client after Relying Party server is started openidConfig.TestClient.SetIngresses(rpServer.URL) - rpHandler.Client = openidclient.NewClient(openidConfig) return &IdentityProvider{ cancelFunc: cancel, @@ -122,18 +117,22 @@ func identityProviderRouter(ip *IdentityProviderHandler) chi.Router { } type IdentityProviderHandler struct { - Codes map[string]*AuthorizeRequest - Config openidconfig.Config - Provider *TestProvider - Sessions map[string]string + Codes map[string]*AuthorizeRequest + Config openidconfig.Config + Provider *TestProvider + Sessions map[string]string + RefreshTokens map[string]*RefreshTokenData + TokenDuration time.Duration } func newIdentityProviderHandler(provider *TestProvider, cfg openidconfig.Config) *IdentityProviderHandler { return &IdentityProviderHandler{ - Codes: make(map[string]*AuthorizeRequest), - Config: cfg, - Provider: provider, - Sessions: make(map[string]string), + Codes: make(map[string]*AuthorizeRequest), + Config: cfg, + Provider: provider, + Sessions: make(map[string]string), + RefreshTokens: make(map[string]*RefreshTokenData), + TokenDuration: time.Minute, } } @@ -147,6 +146,13 @@ type AuthorizeRequest struct { SessionID string } +type RefreshTokenData struct { + ClientID string + RefreshToken string + OriginalIDToken jwt.Token + SessionID string +} + type tokenResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` @@ -299,8 +305,23 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return } - code := r.PostForm.Get("code") + grantType := r.PostForm.Get(openid.GrantType) + switch grantType { + case "authorization_code": + ip.TokenCodeGrant(w, r) + return + case "refresh_token": + ip.RefreshTokenGrant(w, r) + return + default: + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("unsupported grant_type: " + grantType)) + return + } +} +func (ip *IdentityProviderHandler) TokenCodeGrant(w http.ResponseWriter, r *http.Request) { + code := r.PostForm.Get("code") if len(code) == 0 { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("missing code")) @@ -314,23 +335,9 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return } - clientID := r.PostForm.Get("client_id") - if len(clientID) == 0 { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing client_id")) - return - } - - if auth.ClientID != clientID { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("client_id does not match client_id used to acquire code")) - return - } - - clientAssertion := r.PostForm.Get("client_assertion") - if len(clientID) == 0 { - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing client_assertion")) + err := ip.validateClientAuthentication(w, r, auth.ClientID) + if err != nil { + w.Write([]byte(err.Error())) return } @@ -353,34 +360,6 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return } - clientJwk := ip.Config.Client().ClientJWK() - clientJwkSet := jwk.NewSet() - clientJwkSet.AddKey(clientJwk) - publicClientJwkSet, err := jwk.PublicSetOf(clientJwkSet) - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("failed to create public client jwk set")) - return - } - - opts := []jwt.ParseOption{ - jwt.WithValidate(true), - jwt.WithKeySet(publicClientJwkSet), - jwt.WithIssuer(ip.Config.Client().ClientID()), - jwt.WithSubject(ip.Config.Client().ClientID()), - jwt.WithAudience(ip.Config.Provider().Issuer()), - } - _, err = jwt.Parse([]byte(clientAssertion), opts...) - if err != nil { - w.WriteHeader(http.StatusUnauthorized) - v := url.Values{} - v.Set("error", "Unauthenticated") - v.Set("error_description", "invalid client assertion") - v.Encode() - w.Write([]byte(fmt.Sprintf(v.Encode()+"%+v", err))) - return - } - codeVerifier := r.PostForm.Get("code_verifier") if len(codeVerifier) == 0 { w.WriteHeader(http.StatusBadRequest) @@ -396,9 +375,8 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return } - expires := int64(1200) iat := time.Now().Truncate(time.Second) - exp := iat.Add(time.Duration(expires) * time.Second) + exp := iat.Add(ip.TokenDuration) sub := uuid.New().String() accessToken := jwt.New() @@ -438,12 +416,21 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) return } + refreshToken := code + "some-refresh-token" + token := &tokenResponse{ AccessToken: signedAccessToken, TokenType: "Bearer", IDToken: signedIdToken, - RefreshToken: code + "some-refresh-token", - ExpiresIn: expires, + RefreshToken: refreshToken, + ExpiresIn: int64(ip.TokenDuration.Seconds()), + } + + ip.RefreshTokens[refreshToken] = &RefreshTokenData{ + ClientID: auth.ClientID, + RefreshToken: refreshToken, + OriginalIDToken: idToken, + SessionID: auth.SessionID, } w.Header().Set("content-type", "application/json") @@ -451,6 +438,116 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) json.NewEncoder(w).Encode(token) } +func (ip *IdentityProviderHandler) RefreshTokenGrant(w http.ResponseWriter, r *http.Request) { + refreshToken := r.PostForm.Get("refresh_token") + if len(refreshToken) == 0 { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing refresh_token")) + return + } + + data, ok := ip.RefreshTokens[refreshToken] + if !ok { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("no matching refresh_token")) + return + } + + err := ip.validateClientAuthentication(w, r, data.ClientID) + if err != nil { + w.Write([]byte(err.Error())) + return + } + + iat := time.Now().Truncate(time.Second) + exp := iat.Add(ip.TokenDuration) + sub := data.OriginalIDToken.Subject() + + accessToken := jwt.New() + accessToken.Set("sub", sub) + accessToken.Set("iss", ip.Config.Provider().Issuer()) + accessToken.Set("iat", iat.Unix()) + accessToken.Set("exp", exp.Unix()) + accessToken.Set("jti", uuid.NewString()) + signedAccessToken, err := ip.signToken(accessToken) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("could not sign access token: " + err.Error())) + return + } + + // remove provided refresh_token as it is now used + delete(ip.RefreshTokens, refreshToken) + + // generate and store a new refresh_token + refreshToken = uuid.NewString() + "some-new-refresh-token" + + token := &tokenResponse{ + AccessToken: signedAccessToken, + TokenType: "Bearer", + RefreshToken: refreshToken, + ExpiresIn: int64(ip.TokenDuration.Seconds()), + } + + ip.RefreshTokens[refreshToken] = &RefreshTokenData{ + ClientID: data.ClientID, + RefreshToken: refreshToken, + OriginalIDToken: data.OriginalIDToken, + SessionID: data.SessionID, + } + + w.Header().Set("content-type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(token) +} + +func (ip *IdentityProviderHandler) validateClientAuthentication(w http.ResponseWriter, r *http.Request, expectedClientID string) error { + clientID := r.PostForm.Get("client_id") + if len(clientID) == 0 { + w.WriteHeader(http.StatusBadRequest) + return fmt.Errorf("missing client_id") + } + + if expectedClientID != clientID { + w.WriteHeader(http.StatusBadRequest) + return fmt.Errorf("client_id does not match client_id for original authorization") + } + + clientAssertion := r.PostForm.Get("client_assertion") + if len(clientAssertion) == 0 { + w.WriteHeader(http.StatusBadRequest) + return fmt.Errorf("missing client_assertion") + } + + clientJwk := ip.Config.Client().ClientJWK() + clientJwkSet := jwk.NewSet() + clientJwkSet.AddKey(clientJwk) + publicClientJwkSet, err := jwk.PublicSetOf(clientJwkSet) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return fmt.Errorf("failed to create public client jwk set") + } + + opts := []jwt.ParseOption{ + jwt.WithValidate(true), + jwt.WithKeySet(publicClientJwkSet), + jwt.WithIssuer(ip.Config.Client().ClientID()), + jwt.WithSubject(ip.Config.Client().ClientID()), + jwt.WithAudience(ip.Config.Provider().Issuer()), + } + _, err = jwt.Parse([]byte(clientAssertion), opts...) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + v := url.Values{} + v.Set("error", "Unauthenticated") + v.Set("error_description", "invalid client assertion") + v.Encode() + return fmt.Errorf("%s: %+v", v.Encode(), err) + } + + return nil +} + func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() postLogoutRedirectURI := query.Get("post_logout_redirect_uri") diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index f54b08c..ad77895 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -2,8 +2,12 @@ package client import ( "context" + "encoding/json" "fmt" + "io" "net/http" + "net/url" + "strings" "time" "github.com/google/uuid" @@ -20,6 +24,8 @@ type Client interface { config() openidconfig.Config oAuth2Config() *oauth2.Config + SetHttpClient(c *http.Client) + Login(r *http.Request, loginstatus loginstatus.Loginstatus) (Login, error) LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error) Logout(r *http.Request) (Logout, error) @@ -28,11 +34,12 @@ type Client interface { AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error) MakeAssertion(expiration time.Duration) (string, error) - RefreshGrant(r *http.Request) error + RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) } type client struct { cfg openidconfig.Config + httpClient *http.Client oauth2Config *oauth2.Config } @@ -49,6 +56,7 @@ func NewClient(cfg openidconfig.Config) Client { return &client{ cfg: cfg, + httpClient: http.DefaultClient, oauth2Config: oauth2Config, } } @@ -61,6 +69,10 @@ func (c *client) oAuth2Config() *oauth2.Config { return c.oauth2Config } +func (c *client) SetHttpClient(httpClient *http.Client) { + c.httpClient = httpClient +} + func (c *client) Login(r *http.Request, loginstatus loginstatus.Loginstatus) (Login, error) { login, err := NewLogin(c, r, loginstatus) if err != nil { @@ -132,7 +144,50 @@ func (c *client) MakeAssertion(expiration time.Duration) (string, error) { return string(encoded), nil } -func (c *client) RefreshGrant(r *http.Request) error { - //TODO implement me - panic("implement me") +func (c *client) RefreshGrant(ctx context.Context, refreshToken string) (*openid.TokenResponse, error) { + assertion, err := c.MakeAssertion(30 * time.Second) + if err != nil { + return nil, fmt.Errorf("creating client assertion: %w", err) + } + + v := url.Values{} + v.Set(openid.GrantType, openid.RefreshTokenValue) + v.Set(openid.RefreshToken, refreshToken) + v.Set(openid.ClientID, c.config().Client().ClientID()) + v.Set(openid.ClientAssertion, assertion) + v.Set(openid.ClientAssertionType, openid.ClientAssertionTypeJwtBearer) + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.config().Provider().TokenEndpoint(), strings.NewReader(v.Encode())) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + r.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.httpClient.Do(r) + if err != nil { + return nil, fmt.Errorf("performing request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading server response: %w", err) + } + + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + var errorResponse openid.TokenErrorResponse + if err := json.Unmarshal(body, &errorResponse); err != nil { + return nil, fmt.Errorf("client error: HTTP %d: unmarshalling error response: %w", resp.StatusCode, err) + } + return nil, fmt.Errorf("client error: HTTP %d: %s: %s", resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription) + } else if resp.StatusCode >= 500 { + return nil, fmt.Errorf("server error: HTTP %d: %s", resp.StatusCode, body) + } + + var tokenResponse openid.TokenResponse + if err := json.Unmarshal(body, &tokenResponse); err != nil { + return nil, fmt.Errorf("unmarshalling token response: %w", err) + } + + return &tokenResponse, nil } diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index b9c31a6..1b57b5e 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -13,10 +13,6 @@ import ( "github.com/nais/wonderwall/pkg/openid/provider" ) -const ( - ClientAssertionTypeJwtBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" -) - type LoginCallback interface { IdentityProviderError() error StateMismatchError() error @@ -79,7 +75,7 @@ func (in *loginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, erro opts := []oauth2.AuthCodeOption{ oauth2.SetAuthURLParam(openid.CodeVerifier, in.cookie.CodeVerifier), oauth2.SetAuthURLParam(openid.ClientAssertion, clientAssertion), - oauth2.SetAuthURLParam(openid.ClientAssertionType, ClientAssertionTypeJwtBearer), + oauth2.SetAuthURLParam(openid.ClientAssertionType, openid.ClientAssertionTypeJwtBearer), oauth2.SetAuthURLParam(openid.RedirectURI, in.cookie.RedirectURI), } diff --git a/pkg/openid/config/client.go b/pkg/openid/config/client.go index 30e063b..a3474e2 100644 --- a/pkg/openid/config/client.go +++ b/pkg/openid/config/client.go @@ -132,6 +132,7 @@ func (in *client) Azure() Client { func (in *azure) Scopes() scopes.Scopes { return scopes.DefaultScopes(). WithAzureScope(in.OpenID.ClientID). + WithOfflineAccess(). WithAdditional(in.OpenID.Scopes...) } diff --git a/pkg/openid/params.go b/pkg/openid/params.go index 24d853d..d2db308 100644 --- a/pkg/openid/params.go +++ b/pkg/openid/params.go @@ -4,12 +4,14 @@ const ( ACRValues = "acr_values" ClientAssertion = "client_assertion" ClientAssertionType = "client_assertion_type" + ClientID = "client_id" CodeChallenge = "code_challenge" CodeChallengeMethod = "code_challenge_method" Code = "code" CodeVerifier = "code_verifier" Error = "error" ErrorDescription = "error_description" + GrantType = "grant_type" IDTokenHint = "id_token_hint" Nonce = "nonce" PostLogoutRedirectURI = "post_logout_redirect_uri" @@ -17,6 +19,7 @@ const ( Sid = "sid" State = "state" RedirectURI = "redirect_uri" + RefreshToken = "refresh_token" Resource = "resource" ResponseMode = "response_mode" UILocales = "ui_locales" diff --git a/pkg/openid/params_values.go b/pkg/openid/params_values.go new file mode 100644 index 0000000..5440ea6 --- /dev/null +++ b/pkg/openid/params_values.go @@ -0,0 +1,6 @@ +package openid + +const ( + ClientAssertionTypeJwtBearer = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + RefreshTokenValue = "refresh_token" +) diff --git a/pkg/openid/response.go b/pkg/openid/response.go new file mode 100644 index 0000000..d78626c --- /dev/null +++ b/pkg/openid/response.go @@ -0,0 +1,17 @@ +package openid + +// TokenResponse is the struct representing the HTTP response from OpenID Connect providers returning a token in +// JSON form. +type TokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` +} + +// TokenErrorResponse is the struct representing the HTTP error response returned from OpenID Connect providers +// in JSON form. +type TokenErrorResponse struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} diff --git a/pkg/openid/scopes/scopes.go b/pkg/openid/scopes/scopes.go index f52bc0e..38e0aec 100644 --- a/pkg/openid/scopes/scopes.go +++ b/pkg/openid/scopes/scopes.go @@ -7,6 +7,7 @@ import ( const ( OpenID = "openid" + OfflineAccess = "offline_access" AzureAPITemplate = "api://%s/.default" ) @@ -24,6 +25,10 @@ func (s Scopes) WithAzureScope(clientID string) Scopes { return append(s, fmt.Sprintf(AzureAPITemplate, clientID)) } +func (s Scopes) WithOfflineAccess() Scopes { + return append(s, OfflineAccess) +} + func DefaultScopes() Scopes { return []string{OpenID} } diff --git a/pkg/router/paths/paths.go b/pkg/router/paths/paths.go index 91482d3..8a22e61 100644 --- a/pkg/router/paths/paths.go +++ b/pkg/router/paths/paths.go @@ -7,4 +7,6 @@ const ( Logout = "/logout" LogoutCallback = "/logout/callback" FrontChannelLogout = "/logout/frontchannel" + Session = "/session" + SessionRefresh = "/session/refresh" ) diff --git a/pkg/router/router.go b/pkg/router/router.go index 67e0610..1b3036d 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -36,6 +36,11 @@ func New(handler *handler.Handler) chi.Router { r.Get(paths.Logout, handler.Logout) r.Get(paths.FrontChannelLogout, handler.FrontChannelLogout) r.Get(paths.LogoutCallback, handler.LogoutCallback) + + if handler.Config.Session.Refresh { + r.Get(paths.Session, handler.SessionInfo) + r.Get(paths.SessionRefresh, handler.SessionRefresh) + } }) } }) diff --git a/pkg/session/data.go b/pkg/session/data.go new file mode 100644 index 0000000..4d32a71 --- /dev/null +++ b/pkg/session/data.go @@ -0,0 +1,202 @@ +package session + +import ( + "encoding" + "encoding/base64" + "encoding/json" + "time" + + "github.com/nais/wonderwall/pkg/crypto" + "github.com/nais/wonderwall/pkg/openid" +) + +const ( + RefreshMinInterval = 1 * time.Minute + RefreshLeeway = 5 * time.Minute +) + +type EncryptedData struct { + Data string `json:"data"` +} + +var _ encoding.BinaryMarshaler = &EncryptedData{} +var _ encoding.BinaryUnmarshaler = &EncryptedData{} + +func (in *EncryptedData) MarshalBinary() ([]byte, error) { + return json.Marshal(in) +} + +func (in *EncryptedData) UnmarshalBinary(bytes []byte) error { + return json.Unmarshal(bytes, in) +} + +func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) { + ciphertext, err := base64.StdEncoding.DecodeString(in.Data) + if err != nil { + return nil, err + } + + rawData, err := crypter.Decrypt(ciphertext) + if err != nil { + return nil, err + } + + var data Data + err = json.Unmarshal(rawData, &data) + if err != nil { + return nil, err + } + + return &data, nil +} + +type Data struct { + ExternalSessionID string `json:"external_session_id"` + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + IDTokenJwtID string `json:"id_token_jwt_id"` + Metadata Metadata `json:"metadata"` +} + +func NewData(externalSessionID string, tokens *openid.Tokens, metadata *Metadata) *Data { + data := &Data{ + ExternalSessionID: externalSessionID, + AccessToken: tokens.AccessToken, + IDToken: tokens.IDToken.GetSerialized(), + IDTokenJwtID: tokens.IDToken.GetJwtID(), + RefreshToken: tokens.RefreshToken, + } + + if metadata != nil { + data.Metadata = *metadata + } + + return data +} + +func (in *Data) Encrypt(crypter crypto.Crypter) (*EncryptedData, error) { + bytes, err := json.Marshal(in) + if err != nil { + return nil, err + } + + ciphertext, err := crypter.Encrypt(bytes) + if err != nil { + return nil, err + } + + return &EncryptedData{ + Data: base64.StdEncoding.EncodeToString(ciphertext), + }, nil +} + +func (in *Data) HasAccessToken() bool { + return len(in.AccessToken) > 0 +} + +func (in *Data) HasRefreshToken() bool { + return len(in.RefreshToken) > 0 +} + +type Metadata struct { + // SessionCreatedAt is the time when the session was created. + SessionCreatedAt time.Time `json:"session_created_at"` + // SessionEndsAt is the time when the session will end, i.e. the absolute lifetime/time-to-live for the session. + SessionEndsAt time.Time `json:"session_ends_at"` + // TokensExpireAt is the time when the tokens within the session expires. + TokensExpireAt time.Time `json:"tokens_expire_at"` + // TokensRefreshedAt is the time when the tokens within the session was refreshed. + TokensRefreshedAt time.Time `json:"tokens_refreshed_at"` +} + +func NewMetadata(expiresIn time.Duration, endsIn time.Duration) *Metadata { + now := time.Now() + return &Metadata{ + SessionCreatedAt: now, + SessionEndsAt: now.Add(endsIn), + TokensRefreshedAt: now, + TokensExpireAt: now.Add(expiresIn), + } +} + +func (in *Metadata) NextRefresh() time.Time { + // subtract the leeway to ensure that we refresh before expiry + next := in.TokensExpireAt.Add(-RefreshLeeway) + + // try to refresh at the first opportunity if the next refresh is in the past + if next.Before(time.Now()) { + return in.RefreshCooldown() + } + + return next +} + +func (in *Metadata) Refresh(nextExpirySeconds int64) { + now := time.Now() + in.TokensRefreshedAt = now + in.TokensExpireAt = now.Add(time.Duration(nextExpirySeconds) * time.Second) +} + +func (in *Metadata) RefreshCooldown() time.Time { + refreshed := in.TokensRefreshedAt + tokenLifetime := in.TokenLifetime() + + // if token lifetime is less than the minimum refresh interval * 2, we'll allow refreshes at the token half-life + if tokenLifetime <= RefreshMinInterval*2 { + return refreshed.Add(tokenLifetime / 2) + } + + return refreshed.Add(RefreshMinInterval) +} + +func (in *Metadata) RefreshOnCooldown() bool { + return time.Now().Before(in.RefreshCooldown()) +} + +func (in *Metadata) ShouldRefresh() bool { + if in.RefreshOnCooldown() { + return false + } + + return time.Now().After(in.NextRefresh()) +} + +func (in *Metadata) TokenLifetime() time.Duration { + return in.TokensExpireAt.Sub(in.TokensRefreshedAt) +} + +func (in *Metadata) Verbose() MetadataVerbose { + now := time.Now() + + expireTime := in.TokensExpireAt + endTime := in.SessionEndsAt + nextRefreshTime := in.NextRefresh() + + return MetadataVerbose{ + Metadata: *in, + SessionEndsInSeconds: toSeconds(endTime.Sub(now)), + TokensExpireInSeconds: toSeconds(expireTime.Sub(now)), + TokensNextRefreshInSeconds: toSeconds(nextRefreshTime.Sub(now)), + TokensRefreshCooldown: in.RefreshOnCooldown(), + TokensRefreshCooldownSeconds: toSeconds(in.RefreshCooldown().Sub(now)), + } +} + +type MetadataVerbose struct { + Metadata + SessionEndsInSeconds int64 `json:"session_ends_in_seconds"` + TokensExpireInSeconds int64 `json:"tokens_expire_in_seconds"` + TokensNextRefreshInSeconds int64 `json:"tokens_next_refresh_in_seconds"` + TokensRefreshCooldown bool `json:"tokens_refresh_cooldown"` + TokensRefreshCooldownSeconds int64 `json:"tokens_refresh_cooldown_seconds"` +} + +func toSeconds(d time.Duration) int64 { + i := int64(d.Seconds()) + if i <= 0 { + return 0 + } + + return i +} diff --git a/pkg/session/data_test.go b/pkg/session/data_test.go new file mode 100644 index 0000000..a1278cf --- /dev/null +++ b/pkg/session/data_test.go @@ -0,0 +1,41 @@ +package session_test + +import ( + "testing" +) + +func TestData_HasAccessToken(t *testing.T) { + // TODO +} + +func TestData_HasRefreshToken(t *testing.T) { + // TODO +} + +func TestMetadata_NextRefresh(t *testing.T) { + // TODO +} + +func TestMetadata_Refresh(t *testing.T) { + // TODO +} + +func TestMetadata_RefreshCooldown(t *testing.T) { + // TODO +} + +func TestMetadata_RefreshOnCooldown(t *testing.T) { + // TODO +} + +func TestMetadata_ShouldRefresh(t *testing.T) { + // TODO +} + +func TestMetadata_TokenAge(t *testing.T) { + // TODO +} + +func TestMetadata_Verbose(t *testing.T) { + // TODO +} diff --git a/pkg/session/handler.go b/pkg/session/handler.go index b9e4ea3..dc08ecb 100644 --- a/pkg/session/handler.go +++ b/pkg/session/handler.go @@ -13,54 +13,59 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/crypto" + mw "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" + openidclient "github.com/nais/wonderwall/pkg/openid/client" openidconfig "github.com/nais/wonderwall/pkg/openid/config" retrypkg "github.com/nais/wonderwall/pkg/retry" "github.com/nais/wonderwall/pkg/strings" ) +var ( + CookieNotFoundError = errors.New("cookie not found") +) + type Handler struct { - cfg *config.Config - openidCfg openidconfig.Config - crypter crypto.Crypter - store Store + client openidclient.Client + crypter crypto.Crypter + openidCfg openidconfig.Config + refreshEnabled bool + store Store } -func NewHandler(cfg *config.Config, openidCfg openidconfig.Config, crypter crypto.Crypter) (*Handler, error) { +func NewHandler(cfg *config.Config, openidCfg openidconfig.Config, crypter crypto.Crypter, openidClient openidclient.Client) (*Handler, error) { store, err := NewStore(cfg) if err != nil { return nil, err } return &Handler{ - cfg: cfg, - crypter: crypter, - openidCfg: openidCfg, - store: store, + crypter: crypter, + client: openidClient, + openidCfg: openidCfg, + store: store, + refreshEnabled: cfg.Session.Refresh, }, nil } // Create creates and stores a session in the Store, and returns the session's key. -func (h *Handler) Create(r *http.Request, tokens *openid.Tokens, expiresIn time.Duration) (string, error) { +func (h *Handler) Create(r *http.Request, tokens *openid.Tokens, sessionLifetime time.Duration) (string, error) { externalSessionID, err := h.IDOrGenerate(r, tokens) if err != nil { return "", fmt.Errorf("generating session ID: %w", err) } key := h.Key(externalSessionID) - metadata := NewMetadata(time.Now().Add(expiresIn)) + tokenExpiresIn := tokens.Expiry.Sub(time.Now()) + metadata := NewMetadata(tokenExpiresIn, sessionLifetime) encrypted, err := NewData(externalSessionID, tokens, metadata).Encrypt(h.crypter) if err != nil { return "", fmt.Errorf("encrypting session data: %w", err) } retryable := func(ctx context.Context) error { - err = h.store.Write(r.Context(), key, encrypted, expiresIn) - if err != nil { - return retry.RetryableError(err) - } - - return nil + err = h.store.Write(r.Context(), key, encrypted, sessionLifetime) + return retry.RetryableError(err) } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { @@ -99,21 +104,12 @@ func (h *Handler) destroyForKey(r *http.Request, key string) error { // Get returns the session data for a given http.Request, matching by the session cookie. func (h *Handler) Get(r *http.Request) (*Data, error) { - key, err := cookie.GetDecrypted(r, cookie.Session, h.crypter) + key, err := h.GetKey(r) if err != nil { return nil, fmt.Errorf("no session cookie: %w", err) } - sessionData, err := h.GetForKey(r, key) - if err == nil { - return sessionData, nil - } - - if errors.Is(err, KeyNotFoundError) { - return nil, fmt.Errorf("session not found: %w", err) - } - - return nil, err + return h.GetForKey(r, key) } // GetForID returns the session data for a given session ID. @@ -152,6 +148,42 @@ func (h *Handler) GetForKey(r *http.Request, key string) (*Data, error) { return sessionData, nil } +// GetKey extracts the session Key from the session cookie found in the request, if any. +func (h *Handler) GetKey(r *http.Request) (string, error) { + key, err := cookie.GetDecrypted(r, cookie.Session, h.crypter) + if err != nil { + return "", fmt.Errorf("%w: %+v", CookieNotFoundError, err) + } + + return key, nil +} + +// GetOrRefresh returns the session data, performing refreshes if enabled and necessary. +func (h *Handler) GetOrRefresh(r *http.Request) (*Data, error) { + key, err := h.GetKey(r) + if err != nil { + return nil, err + } + + sessionData, err := h.GetForKey(r, key) + if err != nil { + return nil, err + } + + if !h.refreshEnabled || !sessionData.HasRefreshToken() || !sessionData.Metadata.ShouldRefresh() { + return sessionData, nil + } + + refreshed, err := h.Refresh(r, key, sessionData) + if err != nil { + mw.LogEntryFrom(r).Warnf("session: could not refresh tokens, falling back to existing token: %+v", err) + } else { + sessionData = refreshed + } + + return sessionData, nil +} + // IDOrGenerate returns the session ID, derived from the given request or id_token; e.g. `sid` or `session_state`. // If none are present, a generated ID is returned. func (h *Handler) IDOrGenerate(r *http.Request, tokens *openid.Tokens) (string, error) { @@ -172,6 +204,48 @@ func (h *Handler) Key(sessionID string) string { return fmt.Sprintf("%s:%s:%s", provider.Name(), client.ClientID(), sessionID) } +// Refresh refreshes the user's session and returns the updated session data. +func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error) { + if !h.refreshEnabled || !data.HasRefreshToken() || data.Metadata.RefreshOnCooldown() { + return data, nil + } + + logger := mw.LogEntryFrom(r) + logger.Info("session: refreshing token...") + + var resp *openid.TokenResponse + var err error + + refresh := func(ctx context.Context) error { + resp, err = h.client.RefreshGrant(r.Context(), data.RefreshToken) + return retry.RetryableError(err) + } + if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, refresh); err != nil { + return nil, fmt.Errorf("performing refresh: %w", err) + } + + data.AccessToken = resp.AccessToken + data.RefreshToken = resp.RefreshToken + data.Metadata.Refresh(resp.ExpiresIn) + + encrypted, err := data.Encrypt(h.crypter) + if err != nil { + return nil, fmt.Errorf("encrypting session data: %w", err) + } + + update := func(ctx context.Context) error { + err = h.store.Update(r.Context(), key, encrypted) + return retry.RetryableError(err) + } + + if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, update); err != nil { + return nil, fmt.Errorf("updating in store: %w", err) + } + + logger.Info("session: successfully refreshed") + return data, nil +} + func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url.Values) (string, error) { // 1. check for 'sid' claim in id_token sessionID, err := idToken.GetSidClaim() diff --git a/pkg/session/models.go b/pkg/session/models.go deleted file mode 100644 index 5010401..0000000 --- a/pkg/session/models.go +++ /dev/null @@ -1,105 +0,0 @@ -package session - -import ( - "encoding" - "encoding/base64" - "encoding/json" - "time" - - "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/openid" -) - -type EncryptedData struct { - Data string `json:"data"` -} - -var _ encoding.BinaryMarshaler = &EncryptedData{} -var _ encoding.BinaryUnmarshaler = &EncryptedData{} - -func (in *EncryptedData) MarshalBinary() ([]byte, error) { - return json.Marshal(in) -} - -func (in *EncryptedData) UnmarshalBinary(bytes []byte) error { - return json.Unmarshal(bytes, in) -} - -func (in *EncryptedData) Decrypt(crypter crypto.Crypter) (*Data, error) { - ciphertext, err := base64.StdEncoding.DecodeString(in.Data) - if err != nil { - return nil, err - } - - rawData, err := crypter.Decrypt(ciphertext) - if err != nil { - return nil, err - } - - var data Data - err = json.Unmarshal(rawData, &data) - if err != nil { - return nil, err - } - - return &data, nil -} - -type Data struct { - ExternalSessionID string `json:"external_session_id"` - AccessToken string `json:"access_token"` - IDToken string `json:"id_token"` - RefreshToken string `json:"refresh_token"` - IDTokenJwtID string `json:"id_token_jwt_id"` - Metadata Metadata `json:"metadata"` -} - -func NewData(externalSessionID string, tokens *openid.Tokens, metadata *Metadata) *Data { - data := &Data{ - ExternalSessionID: externalSessionID, - AccessToken: tokens.AccessToken, - IDToken: tokens.IDToken.GetSerialized(), - IDTokenJwtID: tokens.IDToken.GetJwtID(), - RefreshToken: tokens.RefreshToken, - } - - if metadata != nil { - data.Metadata = *metadata - } - - return data -} - -func (in *Data) Encrypt(crypter crypto.Crypter) (*EncryptedData, error) { - bytes, err := json.Marshal(in) - if err != nil { - return nil, err - } - - ciphertext, err := crypter.Encrypt(bytes) - if err != nil { - return nil, err - } - - return &EncryptedData{ - Data: base64.StdEncoding.EncodeToString(ciphertext), - }, nil -} - -type Metadata struct { - CreatedAt int64 `json:"created_at"` - RefreshedAt int64 `json:"refreshed_at"` - ExpiresAt int64 `json:"expires_at"` -} - -func NewMetadata(expiresAt time.Time) *Metadata { - return &Metadata{ - CreatedAt: time.Now().Unix(), - RefreshedAt: time.Now().Unix(), - ExpiresAt: expiresAt.Unix(), - } -} - -func (in *Metadata) UpdateRefreshedAt() { - in.RefreshedAt = time.Now().Unix() -} diff --git a/pkg/session/store.go b/pkg/session/store.go index ebdb67e..b20a947 100644 --- a/pkg/session/store.go +++ b/pkg/session/store.go @@ -20,6 +20,7 @@ type Store interface { Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error Read(ctx context.Context, key string) (*EncryptedData, error) Delete(ctx context.Context, keys ...string) error + Update(ctx context.Context, key string, value *EncryptedData) error } func NewStore(cfg *config.Config) (Store, error) { diff --git a/pkg/session/store_memory.go b/pkg/session/store_memory.go index 9997957..e62c33e 100644 --- a/pkg/session/store_memory.go +++ b/pkg/session/store_memory.go @@ -50,3 +50,11 @@ func (s *memorySessionStore) Delete(_ context.Context, keys ...string) error { return nil } + +func (s *memorySessionStore) Update(_ context.Context, key string, value *EncryptedData) error { + s.lock.Lock() + defer s.lock.Unlock() + + s.sessions[key] = value + return nil +} diff --git a/pkg/session/store_memory_test.go b/pkg/session/store_memory_test.go index 44d817d..98bf4e2 100644 --- a/pkg/session/store_memory_test.go +++ b/pkg/session/store_memory_test.go @@ -1,57 +1,31 @@ package session_test import ( - "context" "testing" - "time" - jwtlib "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/nais/liberator/pkg/keygen" "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" ) func TestMemory(t *testing.T) { - key, err := keygen.Keygen(32) - assert.NoError(t, err) - crypter := crypto.NewCrypter(key) - - idToken := jwtlib.New() - idToken.Set("jti", "id-token-jti") - - accessToken := "some-access-token" - refreshToken := "some-refresh-token" - - tokens := &openid.Tokens{ - AccessToken: accessToken, - IDToken: openid.NewIDToken("id_token", idToken), - RefreshToken: refreshToken, - } - metadata := session.NewMetadata(time.Now().Add(time.Hour)) - data := session.NewData("myid", tokens, metadata) - + crypter := makeCrypter(t) + data := makeData() encryptedData, err := data.Encrypt(crypter) assert.NoError(t, err) - sess := session.NewMemory() - err = sess.Write(context.Background(), "key", encryptedData, time.Minute) - assert.NoError(t, err) + store := session.NewMemory() + key := "key" - result, err := sess.Read(context.Background(), "key") - assert.NoError(t, err) - assert.Equal(t, encryptedData, result) + write(t, store, key, encryptedData) - decrypted, err := result.Decrypt(crypter) - assert.NoError(t, err) - assert.Equal(t, data, decrypted) + decrypted := read(t, store, key, encryptedData, crypter) + decryptedEqual(t, data, decrypted) - err = sess.Delete(context.Background(), "key") + data, encryptedData = update(t, store, key, data, crypter) - result, err = sess.Read(context.Background(), "key") - assert.Error(t, err) - assert.ErrorIs(t, err, session.KeyNotFoundError) - assert.Nil(t, result) + decrypted = read(t, store, key, encryptedData, crypter) + decryptedEqual(t, data, decrypted) + + del(t, store, key) } diff --git a/pkg/session/store_redis.go b/pkg/session/store_redis.go index a475695..f6b6629 100644 --- a/pkg/session/store_redis.go +++ b/pkg/session/store_redis.go @@ -64,3 +64,14 @@ func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error { return fmt.Errorf("%w: %s", UnexpectedError, err.Error()) } + +func (s *redisSessionStore) Update(ctx context.Context, key string, value *EncryptedData) error { + err := metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error { + return s.client.Set(ctx, key, value, redis.KeepTTL).Err() + }) + if err != nil { + return fmt.Errorf("%w: %s", UnexpectedError, err.Error()) + } + + return nil +} diff --git a/pkg/session/store_redis_test.go b/pkg/session/store_redis_test.go index 442e9d4..f282310 100644 --- a/pkg/session/store_redis_test.go +++ b/pkg/session/store_redis_test.go @@ -1,40 +1,18 @@ package session_test import ( - "context" "testing" - "time" "github.com/alicebob/miniredis/v2" "github.com/go-redis/redis/v8" - jwtlib "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/nais/liberator/pkg/keygen" "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/crypto" - "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/session" ) func TestRedis(t *testing.T) { - key, err := keygen.Keygen(32) - assert.NoError(t, err) - crypter := crypto.NewCrypter(key) - - idToken := jwtlib.New() - idToken.Set("jti", "id-token-jti") - - accessToken := "some-access-token" - refreshToken := "some-refresh-token" - - tokens := &openid.Tokens{ - AccessToken: accessToken, - IDToken: openid.NewIDToken("id_token", idToken), - RefreshToken: refreshToken, - } - metadata := session.NewMetadata(time.Now().Add(time.Hour)) - data := session.NewData("myid", tokens, metadata) - + crypter := makeCrypter(t) + data := makeData() encryptedData, err := data.Encrypt(crypter) assert.NoError(t, err) @@ -49,22 +27,18 @@ func TestRedis(t *testing.T) { Addr: s.Addr(), }) - sess := session.NewRedis(client) - err = sess.Write(context.Background(), "key", encryptedData, time.Minute) - assert.NoError(t, err) + store := session.NewRedis(client) + key := "key" - result, err := sess.Read(context.Background(), "key") - assert.NoError(t, err) - assert.Equal(t, encryptedData, result) + write(t, store, key, encryptedData) - decrypted, err := result.Decrypt(crypter) - assert.NoError(t, err) - assert.Equal(t, data, decrypted) + decrypted := read(t, store, key, encryptedData, crypter) + decryptedEqual(t, data, decrypted) - err = sess.Delete(context.Background(), "key") + data, encryptedData = update(t, store, key, data, crypter) - result, err = sess.Read(context.Background(), "key") - assert.Error(t, err) - assert.ErrorIs(t, err, session.KeyNotFoundError) - assert.Nil(t, result) + decrypted = read(t, store, key, encryptedData, crypter) + decryptedEqual(t, data, decrypted) + + del(t, store, key) } diff --git a/pkg/session/store_test.go b/pkg/session/store_test.go new file mode 100644 index 0000000..f9822a4 --- /dev/null +++ b/pkg/session/store_test.go @@ -0,0 +1,91 @@ +package session_test + +import ( + "context" + "testing" + "time" + + jwtlib "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/nais/liberator/pkg/keygen" + "github.com/stretchr/testify/assert" + + "github.com/nais/wonderwall/pkg/crypto" + "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/session" +) + +func decryptedEqual(t *testing.T, expected, actual *session.Data) { + assert.Equal(t, expected.AccessToken, actual.AccessToken) + assert.Equal(t, expected.RefreshToken, actual.RefreshToken) + assert.Equal(t, expected.IDToken, actual.IDToken) + assert.Equal(t, expected.IDTokenJwtID, actual.IDTokenJwtID) + assert.Equal(t, expected.ExternalSessionID, actual.ExternalSessionID) + + assert.WithinDuration(t, expected.Metadata.SessionCreatedAt, actual.Metadata.SessionCreatedAt, 0) + assert.WithinDuration(t, expected.Metadata.SessionEndsAt, actual.Metadata.SessionEndsAt, 0) + assert.WithinDuration(t, expected.Metadata.TokensExpireAt, actual.Metadata.TokensExpireAt, 0) + assert.WithinDuration(t, expected.Metadata.TokensRefreshedAt, actual.Metadata.TokensRefreshedAt, 0) +} + +func makeCrypter(t *testing.T) crypto.Crypter { + key, err := keygen.Keygen(32) + assert.NoError(t, err) + return crypto.NewCrypter(key) +} + +func makeData() *session.Data { + idToken := jwtlib.New() + idToken.Set("jti", "id-token-jti") + + accessToken := "some-access-token" + refreshToken := "some-refresh-token" + + tokens := &openid.Tokens{ + AccessToken: accessToken, + IDToken: openid.NewIDToken("id_token", idToken), + RefreshToken: refreshToken, + } + + expiresIn := time.Hour + endsIn := time.Hour + + metadata := session.NewMetadata(expiresIn, endsIn) + return session.NewData("myid", tokens, metadata) +} + +func write(t *testing.T, store session.Store, key string, value *session.EncryptedData) { + err := store.Write(context.Background(), key, value, time.Minute) + assert.NoError(t, err) +} + +func read(t *testing.T, store session.Store, key string, encrypted *session.EncryptedData, crypter crypto.Crypter) *session.Data { + result, err := store.Read(context.Background(), key) + assert.NoError(t, err) + assert.Equal(t, encrypted, result) + + decrypted, err := result.Decrypt(crypter) + assert.NoError(t, err) + + return decrypted +} + +func update(t *testing.T, store session.Store, key string, data *session.Data, crypter crypto.Crypter) (*session.Data, *session.EncryptedData) { + data.AccessToken = "new-access-token" + data.RefreshToken = "new-refresh-token" + encryptedData, err := data.Encrypt(crypter) + assert.NoError(t, err) + + err = store.Update(context.Background(), key, encryptedData) + assert.NoError(t, err) + + return data, encryptedData +} + +func del(t *testing.T, store session.Store, key string) { + err := store.Delete(context.Background(), key) + + result, err := store.Read(context.Background(), key) + assert.Error(t, err) + assert.ErrorIs(t, err, session.KeyNotFoundError) + assert.Nil(t, result) +}