diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index d3761f0..0f2cafc 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -1,19 +1,29 @@ package handler import ( + "context" + "encoding/json" + "errors" + "fmt" "net/http" urllib "net/url" "time" + "github.com/sethvargo/go-retry" + log "github.com/sirupsen/logrus" + "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/handler/autologin" errorhandler "github.com/nais/wonderwall/pkg/handler/error" "github.com/nais/wonderwall/pkg/ingress" - "github.com/nais/wonderwall/pkg/middleware" + "github.com/nais/wonderwall/pkg/metrics" + mw "github.com/nais/wonderwall/pkg/middleware" + "github.com/nais/wonderwall/pkg/openid" openidclient "github.com/nais/wonderwall/pkg/openid/client" openidconfig "github.com/nais/wonderwall/pkg/openid/config" + retrypkg "github.com/nais/wonderwall/pkg/retry" "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/session" "github.com/nais/wonderwall/pkg/url" @@ -34,6 +44,10 @@ type Standalone struct { UpstreamProxy *ReverseProxy } +type LogoutOptions struct { + GlobalLogout bool +} + func NewStandalone( cfg *config.Config, cookieOpts cookie.Options, @@ -114,7 +128,7 @@ func (s *Standalone) GetIngresses() *ingress.Ingresses { } func (s *Standalone) GetPath(r *http.Request) string { - path, ok := middleware.PathFrom(r.Context()) + path, ok := mw.PathFrom(r.Context()) if !ok { path = s.Ingresses.MatchingPath(r) } @@ -135,37 +149,245 @@ func (s *Standalone) GetSessionConfig() config.Session { } func (s *Standalone) Login(w http.ResponseWriter, r *http.Request) { - Login(s, w, r) + canonicalRedirect := s.GetRedirect().Canonical(r) + login, err := s.GetClient().Login(r) + if err != nil { + if errors.Is(err, openidclient.ErrInvalidSecurityLevel) || errors.Is(err, openidclient.ErrInvalidLocale) { + s.GetErrorHandler().BadRequest(w, r, err) + } else { + s.GetErrorHandler().InternalError(w, r, err) + } + + return + } + + opts := s.GetCookieOptsPathAware(r). + WithExpiresIn(1 * time.Hour). + WithSameSite(http.SameSiteNoneMode) + err = login.SetCookie(w, opts, s.GetCrypter(), canonicalRedirect) + if err != nil { + s.GetErrorHandler().InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err)) + return + } + + fields := log.Fields{ + "redirect_after_login": canonicalRedirect, + } + mw.LogEntryFrom(r).WithFields(fields).Info("login: redirecting to identity provider") + http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect) } func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) { - LoginCallback(s, w, r) + opts := s.GetCookieOptsPathAware(r) + + // unconditionally clear login cookies + cookie.Clear(w, cookie.Login, opts.WithSameSite(http.SameSiteNoneMode)) + cookie.Clear(w, cookie.LoginLegacy, opts.WithSameSite(http.SameSiteDefaultMode)) + + loginCookie, err := openid.GetLoginCookie(r, s.GetCrypter()) + if err != nil { + msg := "callback: fetching login cookie" + if errors.Is(err, http.ErrNoCookie) { + msg += ": fallback cookie not found (user might have blocked all cookies, or the callback route was accessed before the login route)" + } + s.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("%s: %w", msg, err)) + return + } + + loginCallback, err := s.GetClient().LoginCallback(r, loginCookie) + if err != nil { + s.GetErrorHandler().InternalError(w, r, err) + return + } + + if err := loginCallback.IdentityProviderError(); err != nil { + s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: %w", err)) + return + } + + if err := loginCallback.StateMismatchError(); err != nil { + s.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("callback: %w", err)) + return + } + + var tokens *openid.Tokens + err = retry.Do(r.Context(), retrypkg.DefaultBackoff, func(ctx context.Context) error { + tokens, err = loginCallback.RedeemTokens(ctx) + return retry.RetryableError(err) + }) + if err != nil { + s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err)) + return + } + + sessionLifetime := s.GetSessionConfig().MaxLifetime + + ticket, err := s.GetSessions().Create(r, tokens, sessionLifetime) + if err != nil { + s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: creating session: %w", err)) + return + } + + err = ticket.Set(w, opts.WithExpiresIn(sessionLifetime), s.GetCrypter()) + if err != nil { + s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err)) + return + } + + redirect := s.GetRedirect().Clean(r, loginCookie.Referer) + + fields := log.Fields{ + "redirect_to": redirect, + "jti": tokens.IDToken.GetJwtID(), + } + + mw.LogEntryFrom(r).WithFields(fields).Info("callback: successful login") + metrics.ObserveLogin() + cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r)) + http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) } func (s *Standalone) Logout(w http.ResponseWriter, r *http.Request) { opts := LogoutOptions{ GlobalLogout: true, } - Logout(s, w, r, opts) + s.logout(w, r, opts) } func (s *Standalone) LogoutLocal(w http.ResponseWriter, r *http.Request) { opts := LogoutOptions{ GlobalLogout: false, } - Logout(s, w, r, opts) + s.logout(w, r, opts) +} + +func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, opts LogoutOptions) { + logger := mw.LogEntryFrom(r) + logout, err := s.GetClient().Logout(r) + if err != nil { + s.GetErrorHandler().InternalError(w, r, err) + return + } + + var idToken string + + sessions := s.GetSessions() + + ticket, err := sessions.GetTicket(r) + if err == nil { + sessionData, err := sessions.Get(r, ticket) + if err == nil && sessionData != nil { + idToken = sessionData.IDToken + + err = sessions.Destroy(r, ticket.Key()) + if err != nil && !errors.Is(err, session.ErrKeyNotFound) { + s.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) + return + } + + logger.WithField("jti", sessionData.IDTokenJwtID). + Info("logout: successful local logout") + metrics.ObserveLogout(metrics.LogoutOperationLocal) + } + } + + cookie.Clear(w, cookie.Session, s.GetCookieOptsPathAware(r)) + + if opts.GlobalLogout { + logger.Debug("logout: redirecting to identity provider for global/single-logout") + metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated) + http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect) + } } func (s *Standalone) LogoutCallback(w http.ResponseWriter, r *http.Request) { - LogoutCallback(s, w, r) + redirect := s.GetClient().LogoutCallback(r).PostLogoutRedirectURI() + + cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r)) + mw.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect) + http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) } func (s *Standalone) LogoutFrontChannel(w http.ResponseWriter, r *http.Request) { - LogoutFrontChannel(s, w, r) + logger := mw.LogEntryFrom(r) + + // Unconditionally destroy all local references to the session. + cookie.Clear(w, cookie.Session, s.GetCookieOptsPathAware(r)) + + sessions := s.GetSessions() + client := s.GetClient() + + getSessionKey := func(r *http.Request) (string, error) { + lfc := client.LogoutFrontchannel(r) + + if lfc.MissingSidParameter() { + ticket, err := sessions.GetTicket(r) + if err != nil { + return ticket.Key(), nil + } + return "", fmt.Errorf("neither sid parameter nor session ticket found in request: %w", err) + } + + sid := lfc.Sid() + return sessions.Key(sid), nil + } + + key, err := getSessionKey(r) + if err != nil { + logger.Debugf("front-channel logout: getting session key: %+v; ignoring", err) + w.WriteHeader(http.StatusAccepted) + return + } + + err = sessions.Destroy(r, key) + if err != nil { + logger.Warnf("front-channel logout: destroying session: %+v", err) + w.WriteHeader(http.StatusAccepted) + return + } + + cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r)) + metrics.ObserveLogout(metrics.LogoutOperationFrontChannel) + w.WriteHeader(http.StatusOK) } func (s *Standalone) Session(w http.ResponseWriter, r *http.Request) { - Session(s, w, r) + logger := mw.LogEntryFrom(r) + + ticket, err := s.GetSessions().GetTicket(r) + if err != nil { + logger.Infof("session/refresh: getting ticket: %+v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + data, err := s.GetSessions().Get(r, ticket) + if err != nil { + switch { + case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound): + logger.Infof("session/info: getting session: %+v", err) + w.WriteHeader(http.StatusUnauthorized) + return + default: + logger.Warnf("session/info: getting session: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + } + + w.Header().Set("Content-Type", "application/json") + + if s.GetSessionConfig().Refresh { + err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh()) + } else { + err = json.NewEncoder(w).Encode(data.Metadata.Verbose()) + } + + if err != nil { + logger.Warnf("session/info: marshalling metadata: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } } func (s *Standalone) SessionRefresh(w http.ResponseWriter, r *http.Request) { @@ -174,7 +396,48 @@ func (s *Standalone) SessionRefresh(w http.ResponseWriter, r *http.Request) { return } - SessionRefresh(s, w, r) + logger := mw.LogEntryFrom(r) + + ticket, err := s.GetSessions().GetTicket(r) + if err != nil { + logger.Infof("session/refresh: getting ticket: %+v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + data, err := s.GetSessions().Get(r, ticket) + if err != nil { + switch { + case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound): + logger.Infof("session/refresh: getting session: %+v", err) + w.WriteHeader(http.StatusUnauthorized) + default: + logger.Warnf("session/refresh: getting session: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + } + return + } + + data, err = s.GetSessions().Refresh(r, ticket, data) + if err != nil { + if errors.Is(err, session.ErrInvalidIdpState) || errors.Is(err, session.ErrInvalidSession) { + logger.Infof("session/refresh: refreshing: %+v", err) + w.WriteHeader(http.StatusUnauthorized) + return + } + + logger.Warnf("session/refresh: refreshing: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh()) + if err != nil { + logger.Warnf("session/refresh: marshalling metadata: %+v", err) + w.WriteHeader(http.StatusInternalServerError) + return + } } func (s *Standalone) ReverseProxy(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/handler/handler_sso_proxy.go b/pkg/handler/handler_sso_proxy.go index 5e5b2a1..2793a27 100644 --- a/pkg/handler/handler_sso_proxy.go +++ b/pkg/handler/handler_sso_proxy.go @@ -5,8 +5,11 @@ import ( "net/http" urllib "net/url" + log "github.com/sirupsen/logrus" + "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/ingress" + logentry "github.com/nais/wonderwall/pkg/middleware" openidclient "github.com/nais/wonderwall/pkg/openid/client" "github.com/nais/wonderwall/pkg/router" "github.com/nais/wonderwall/pkg/router/paths" @@ -56,7 +59,31 @@ func NewSSOProxy(cfg *config.Config) (*SSOProxy, error) { } func (s *SSOProxy) Login(w http.ResponseWriter, r *http.Request) { - LoginSSOProxy(s, w, r) + logger := logentry.LogEntryFrom(r) + + target := s.GetSSOServerURL() + targetQuery := target.Query() + + // override default query parameters + reqQuery := r.URL.Query() + if reqQuery.Has(openidclient.SecurityLevelURLParameter) { + targetQuery.Set(openidclient.SecurityLevelURLParameter, reqQuery.Get(openidclient.SecurityLevelURLParameter)) + } + if reqQuery.Has(openidclient.LocaleURLParameter) { + targetQuery.Set(openidclient.LocaleURLParameter, reqQuery.Get(openidclient.LocaleURLParameter)) + } + + target.RawQuery = reqQuery.Encode() + + canonicalRedirect := s.GetRedirect().Canonical(r) + ssoServerLoginURL := url.Login(target, canonicalRedirect) + + logger.WithFields(log.Fields{ + "redirect_to": ssoServerLoginURL, + "redirect_after_login": canonicalRedirect, + }).Info("login: redirecting to sso server") + + http.Redirect(w, r, ssoServerLoginURL, http.StatusTemporaryRedirect) } func (s *SSOProxy) LoginCallback(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index 6b52270..5d01832 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -3,6 +3,7 @@ package handler_test import ( "encoding/json" "errors" + "fmt" "io" "net/http" "net/http/httptest" @@ -20,6 +21,407 @@ import ( urlpkg "github.com/nais/wonderwall/pkg/url" ) +func TestLogin(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + + resp := localLogin(t, rpClient, idp) + loginURL := resp.Location + + req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login") + + expectedCallbackURL, err := urlpkg.LoginCallback(req) + assert.NoError(t, err) + + assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host)) + assert.Equal(t, "/authorize", loginURL.Path) + assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), loginURL.Query().Get("acr_values")) + assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), loginURL.Query().Get("ui_locales")) + assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), loginURL.Query().Get("client_id")) + assert.Equal(t, expectedCallbackURL, loginURL.Query().Get("redirect_uri")) + assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method")) + assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(loginURL.Query().Get("scope"), " ")) + assert.NotEmpty(t, loginURL.Query().Get("state")) + assert.NotEmpty(t, loginURL.Query().Get("nonce")) + assert.NotEmpty(t, loginURL.Query().Get("code_challenge")) + + resp = get(t, rpClient, loginURL.String()) + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + + callbackURL := resp.Location + assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state")) + assert.NotEmpty(t, callbackURL.Query().Get("code")) +} + +func TestCallback(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) +} + +func TestCallback_SessionStateRequired(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession") + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + + resp := authorize(t, rpClient, idp) + + // Get callback URL after successful auth + params := resp.Location.Query() + sessionState := params.Get("session_state") + assert.NotEmpty(t, sessionState) + + callback(t, rpClient, resp) +} + +func TestLogout(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := selfInitiatedLogout(t, rpClient, idp) + + // Get endsession endpoint after local logout + endsessionURL := resp.Location + + idpserverURL, err := url.Parse(idp.ProviderServer.URL) + assert.NoError(t, err) + + req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback") + expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req) + assert.NoError(t, err) + + endsessionParams := endsessionURL.Query() + assert.Equal(t, idpserverURL.Host, endsessionURL.Host) + assert.Equal(t, "/endsession", endsessionURL.Path) + assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"]) + assert.NotEmpty(t, endsessionParams["id_token_hint"]) +} + +func TestLogoutLocal(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + localLogout(t, rpClient, idp) +} + +func TestLogoutCallback(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + logout(t, rpClient, idp) +} + +func TestFrontChannelLogout(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport() + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + sessionCookie := login(t, rpClient, idp) + + // Trigger front-channel logout + sid := func(r *http.Request) string { + r.AddCookie(sessionCookie) + + ticket, err := session.GetTicket(r, idp.RelyingPartyHandler.GetCrypter()) + assert.NoError(t, err) + + data, err := idp.RelyingPartyHandler.GetSessions().Get(r, ticket) + assert.NoError(t, err) + + return data.ExternalSessionID + } + + frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel") + assert.NoError(t, err) + + req := idp.GetRequest(frontchannelLogoutURL.String()) + + values := url.Values{} + values.Add("sid", sid(req)) + values.Add("iss", idp.OpenIDConfig.Provider().Issuer()) + frontchannelLogoutURL.RawQuery = values.Encode() + + resp := get(t, rpClient, frontchannelLogoutURL.String()) + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestSessionRefresh(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Second + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + // get initial session info + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerboseWithRefresh + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + // wait until refresh cooldown has reached zero before refresh + waitForRefreshCooldownTimer(t, idp, rpClient) + + resp = sessionRefresh(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var refreshedData session.MetadataVerboseWithRefresh + err = json.Unmarshal([]byte(resp.Body), &refreshedData) + assert.NoError(t, err) + + // session create and end times should be unchanged + assert.WithinDuration(t, data.Session.CreatedAt, refreshedData.Session.CreatedAt, 0) + assert.WithinDuration(t, data.Session.EndsAt, refreshedData.Session.EndsAt, 0) + + // token expiration and refresh times should be later than before + assert.True(t, refreshedData.Tokens.ExpireAt.After(data.Tokens.ExpireAt)) + assert.True(t, refreshedData.Tokens.RefreshedAt.After(data.Tokens.RefreshedAt)) + + allowedSkew := 5 * time.Second + assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.Tokens.ExpireAt, allowedSkew) + assert.WithinDuration(t, time.Now(), refreshedData.Tokens.RefreshedAt, allowedSkew) + + sessionEndDuration := time.Duration(refreshedData.Session.EndsInSeconds) * time.Second + // 1 second < time until session ends <= configured max session lifetime + assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime) + assert.Greater(t, sessionEndDuration, time.Second) + + tokenExpiryDuration := time.Duration(refreshedData.Tokens.ExpireInSeconds) * time.Second + // 1 second < time until token expires <= max duration for tokens from IDP + assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) + assert.Greater(t, tokenExpiryDuration, time.Second) + + // 1 second < next token refresh <= seconds until token expires + assert.LessOrEqual(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds) + assert.Greater(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(1)) + + assert.True(t, refreshedData.Tokens.RefreshCooldown) + // 1 second < refresh cooldown <= minimum refresh interval + assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval) + assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1)) + + assert.True(t, data.Session.Active) + assert.True(t, refreshedData.Session.Active) + + assert.True(t, data.Session.TimeoutAt.IsZero()) + assert.True(t, refreshedData.Session.TimeoutAt.IsZero()) + + assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) + assert.Equal(t, int64(-1), refreshedData.Session.TimeoutInSeconds) +} + +func TestSessionRefresh_Disabled(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = false + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Second + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := sessionRefresh(t, idp, rpClient) + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + +func TestSessionRefresh_WithInactivity(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + cfg.Session.Inactivity = true + cfg.Session.InactivityTimeout = 10 * time.Minute + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Second + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + // get initial session info + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerboseWithRefresh + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + // wait until refresh cooldown has reached zero before refresh + waitForRefreshCooldownTimer(t, idp, rpClient) + + resp = sessionRefresh(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var refreshedData session.MetadataVerboseWithRefresh + err = json.Unmarshal([]byte(resp.Body), &refreshedData) + assert.NoError(t, err) + + maxDelta := 5 * time.Second + + assert.True(t, data.Session.Active) + assert.True(t, refreshedData.Session.Active) + + assert.False(t, data.Session.TimeoutAt.IsZero()) + assert.False(t, refreshedData.Session.TimeoutAt.IsZero()) + + expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout) + assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta) + assert.WithinDuration(t, expectedTimeoutAt, refreshedData.Session.TimeoutAt, maxDelta) + + assert.True(t, refreshedData.Session.TimeoutAt.After(data.Session.TimeoutAt)) + + previousTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second + assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(previousTimeoutDuration), maxDelta) + + refreshedTimeoutDuration := time.Duration(refreshedData.Session.TimeoutInSeconds) * time.Second + assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(refreshedTimeoutDuration), maxDelta) +} + +func TestSession(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Minute + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerbose + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + allowedSkew := 5 * time.Second + assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew) + assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew) + assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew) + assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew) + + sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second + // 1 second < time until session ends <= configured max session lifetime + assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime) + assert.Greater(t, sessionEndDuration, time.Second) + + tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second + // 1 second < time until token expires <= max duration for tokens from IDP + assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) + assert.Greater(t, tokenExpiryDuration, time.Second) + + assert.True(t, data.Session.Active) + assert.True(t, data.Session.TimeoutAt.IsZero()) + assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) +} + +func TestSession_WithInactivity(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + cfg.Session.Inactivity = true + cfg.Session.InactivityTimeout = 10 * time.Minute + + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerbose + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + maxDelta := 5 * time.Second + + assert.True(t, data.Session.Active) + assert.False(t, data.Session.TimeoutAt.IsZero()) + + expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout) + assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta) + + actualTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second + assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(actualTimeoutDuration), maxDelta) +} + +func TestSession_WithRefresh(t *testing.T) { + cfg := mock.Config() + cfg.Session.Refresh = true + + idp := mock.NewIdentityProvider(cfg) + idp.ProviderHandler.TokenDuration = 5 * time.Minute + defer idp.Close() + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + + resp := sessionInfo(t, idp, rpClient) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var data session.MetadataVerboseWithRefresh + err := json.Unmarshal([]byte(resp.Body), &data) + assert.NoError(t, err) + + allowedSkew := 5 * time.Second + assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew) + assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew) + assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew) + assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew) + + sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second + // 1 second < time until session ends <= configured max session lifetime + assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime) + assert.Greater(t, sessionEndDuration, time.Second) + + tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second + // 1 second < time until token expires <= max duration for tokens from IDP + assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) + assert.Greater(t, tokenExpiryDuration, time.Second) + + // 1 second < next token refresh <= seconds until token expires + assert.LessOrEqual(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds) + assert.Greater(t, data.Tokens.NextAutoRefreshInSeconds, int64(1)) + + assert.True(t, data.Tokens.RefreshCooldown) + // 1 second < refresh cooldown <= minimum refresh interval + assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval) + assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1)) + + assert.True(t, data.Session.Active) + assert.True(t, data.Session.TimeoutAt.IsZero()) + assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) +} + func localLogin(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response { // First, run /oauth2/login to set cookies loginURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/login") diff --git a/pkg/handler/login.go b/pkg/handler/login.go deleted file mode 100644 index 6e14d7e..0000000 --- a/pkg/handler/login.go +++ /dev/null @@ -1,116 +0,0 @@ -package handler - -import ( - "encoding/json" - "errors" - "fmt" - "net/http" - "net/url" - "time" - - log "github.com/sirupsen/logrus" - - "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/crypto" - errorhandler "github.com/nais/wonderwall/pkg/handler/error" - logentry "github.com/nais/wonderwall/pkg/middleware" - "github.com/nais/wonderwall/pkg/openid" - openidclient "github.com/nais/wonderwall/pkg/openid/client" - urlpkg "github.com/nais/wonderwall/pkg/url" -) - -const ( - CookieLifetime = 1 * time.Hour -) - -type LoginSource interface { - GetClient() *openidclient.Client - GetCookieOptsPathAware(r *http.Request) cookie.Options - GetCrypter() crypto.Crypter - GetErrorHandler() errorhandler.Handler - GetRedirect() urlpkg.Redirect -} - -func Login(src LoginSource, w http.ResponseWriter, r *http.Request) { - canonicalRedirect := src.GetRedirect().Canonical(r) - login, err := src.GetClient().Login(r) - if err != nil { - if errors.Is(err, openidclient.ErrInvalidSecurityLevel) || errors.Is(err, openidclient.ErrInvalidLocale) { - src.GetErrorHandler().BadRequest(w, r, err) - } else { - src.GetErrorHandler().InternalError(w, r, err) - } - - return - } - - err = setLoginCookies(src, w, r, login.Cookie(canonicalRedirect)) - if err != nil { - src.GetErrorHandler().InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err)) - return - } - - fields := log.Fields{ - "redirect_after_login": canonicalRedirect, - } - logentry.LogEntryFrom(r).WithFields(fields).Info("login: redirecting to identity provider") - http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect) -} - -type LoginSSOProxySource interface { - GetSSOServerURL() *url.URL - GetRedirect() urlpkg.Redirect -} - -func LoginSSOProxy(src LoginSSOProxySource, w http.ResponseWriter, r *http.Request) { - logger := logentry.LogEntryFrom(r) - - target := src.GetSSOServerURL() - targetQuery := target.Query() - - // override default query parameters - reqQuery := r.URL.Query() - if reqQuery.Has(openidclient.SecurityLevelURLParameter) { - targetQuery.Set(openidclient.SecurityLevelURLParameter, reqQuery.Get(openidclient.SecurityLevelURLParameter)) - } - if reqQuery.Has(openidclient.LocaleURLParameter) { - targetQuery.Set(openidclient.LocaleURLParameter, reqQuery.Get(openidclient.LocaleURLParameter)) - } - - target.RawQuery = reqQuery.Encode() - - canonicalRedirect := src.GetRedirect().Canonical(r) - ssoServerLoginURL := urlpkg.Login(target, canonicalRedirect) - - logger.WithFields(log.Fields{ - "redirect_to": ssoServerLoginURL, - "redirect_after_login": canonicalRedirect, - }).Info("login: redirecting to sso server") - - http.Redirect(w, r, ssoServerLoginURL, http.StatusTemporaryRedirect) -} - -func setLoginCookies(src LoginSource, w http.ResponseWriter, r *http.Request, loginCookie *openid.LoginCookie) error { - loginCookieJson, err := json.Marshal(loginCookie) - if err != nil { - return fmt.Errorf("marshalling login cookie: %w", err) - } - - opts := src.GetCookieOptsPathAware(r). - WithExpiresIn(CookieLifetime). - WithSameSite(http.SameSiteNoneMode) - value := string(loginCookieJson) - - err = cookie.EncryptAndSet(w, cookie.Login, value, opts, src.GetCrypter()) - if err != nil { - return err - } - - // set a duplicate cookie without the SameSite value set for user agents that do not properly handle SameSite - err = cookie.EncryptAndSet(w, cookie.LoginLegacy, value, opts.WithSameSite(http.SameSiteDefaultMode), src.GetCrypter()) - if err != nil { - return err - } - - return nil -} diff --git a/pkg/handler/login_callback.go b/pkg/handler/login_callback.go deleted file mode 100644 index 35b1ccb..0000000 --- a/pkg/handler/login_callback.go +++ /dev/null @@ -1,125 +0,0 @@ -package handler - -import ( - "context" - "errors" - "fmt" - "net/http" - - "github.com/sethvargo/go-retry" - log "github.com/sirupsen/logrus" - - "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/crypto" - errorhandler "github.com/nais/wonderwall/pkg/handler/error" - "github.com/nais/wonderwall/pkg/metrics" - logentry "github.com/nais/wonderwall/pkg/middleware" - "github.com/nais/wonderwall/pkg/openid" - openidclient "github.com/nais/wonderwall/pkg/openid/client" - retrypkg "github.com/nais/wonderwall/pkg/retry" - "github.com/nais/wonderwall/pkg/session" - "github.com/nais/wonderwall/pkg/url" -) - -type LoginCallbackSource interface { - GetClient() *openidclient.Client - GetCookieOptions() cookie.Options - GetCookieOptsPathAware(r *http.Request) cookie.Options - GetCrypter() crypto.Crypter - GetErrorHandler() errorhandler.Handler - GetRedirect() url.Redirect - GetSessions() *session.Handler - GetSessionConfig() config.Session -} - -func LoginCallback(src LoginCallbackSource, w http.ResponseWriter, r *http.Request) { - // unconditionally clear login cookie - clearLoginCookies(src, w, r) - - loginCookie, err := openid.GetLoginCookie(r, src.GetCrypter()) - if err != nil { - msg := "callback: fetching login cookie" - if errors.Is(err, http.ErrNoCookie) { - msg += ": fallback cookie not found (user might have blocked all cookies, or the callback route was accessed before the login route)" - } - src.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("%s: %w", msg, err)) - return - } - - loginCallback, err := src.GetClient().LoginCallback(r, loginCookie) - if err != nil { - src.GetErrorHandler().InternalError(w, r, err) - return - } - - if err := loginCallback.IdentityProviderError(); err != nil { - src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: %w", err)) - return - } - - if err := loginCallback.StateMismatchError(); err != nil { - src.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("callback: %w", err)) - return - } - - tokens, err := redeemValidTokens(r, loginCallback) - if err != nil { - src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err)) - return - } - - sessionLifetime := src.GetSessionConfig().MaxLifetime - - ticket, err := src.GetSessions().Create(r, tokens, sessionLifetime) - if err != nil { - src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: creating session: %w", err)) - return - } - - opts := src.GetCookieOptsPathAware(r). - WithExpiresIn(sessionLifetime) - err = ticket.Set(w, opts, src.GetCrypter()) - if err != nil { - src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err)) - return - } - - redirect := src.GetRedirect().Clean(r, loginCookie.Referer) - - logSuccessfulLogin(r, tokens, redirect) - cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r)) - http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) -} - -func clearLoginCookies(src LogoutCallbackSource, w http.ResponseWriter, r *http.Request) { - opts := src.GetCookieOptsPathAware(r) - cookie.Clear(w, cookie.Login, opts.WithSameSite(http.SameSiteNoneMode)) - cookie.Clear(w, cookie.LoginLegacy, opts.WithSameSite(http.SameSiteDefaultMode)) -} - -func redeemValidTokens(r *http.Request, loginCallback *openidclient.LoginCallback) (*openid.Tokens, error) { - var tokens *openid.Tokens - var err error - - retryable := func(ctx context.Context) error { - tokens, err = loginCallback.RedeemTokens(ctx) - return retry.RetryableError(err) - } - - if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { - return nil, err - } - - return tokens, nil -} - -func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string) { - fields := log.Fields{ - "redirect_to": referer, - "jti": tokens.IDToken.GetJwtID(), - } - - logentry.LogEntryFrom(r).WithFields(fields).Info("callback: successful login") - metrics.ObserveLogin() -} diff --git a/pkg/handler/login_callback_test.go b/pkg/handler/login_callback_test.go deleted file mode 100644 index 8d2ac06..0000000 --- a/pkg/handler/login_callback_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package handler_test - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/mock" -) - -func TestCallback(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) -} - -func TestCallback_SessionStateRequired(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession") - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - - resp := authorize(t, rpClient, idp) - - // Get callback URL after successful auth - params := resp.Location.Query() - sessionState := params.Get("session_state") - assert.NotEmpty(t, sessionState) - - callback(t, rpClient, resp) -} diff --git a/pkg/handler/login_test.go b/pkg/handler/login_test.go deleted file mode 100644 index 6838875..0000000 --- a/pkg/handler/login_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package handler_test - -import ( - "fmt" - "net/http" - "strings" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/mock" - urlpkg "github.com/nais/wonderwall/pkg/url" -) - -func TestLogin(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - - resp := localLogin(t, rpClient, idp) - loginURL := resp.Location - - req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login") - - expectedCallbackURL, err := urlpkg.LoginCallback(req) - assert.NoError(t, err) - - assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host)) - assert.Equal(t, "/authorize", loginURL.Path) - assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), loginURL.Query().Get("acr_values")) - assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), loginURL.Query().Get("ui_locales")) - assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), loginURL.Query().Get("client_id")) - assert.Equal(t, expectedCallbackURL, loginURL.Query().Get("redirect_uri")) - assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method")) - assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(loginURL.Query().Get("scope"), " ")) - assert.NotEmpty(t, loginURL.Query().Get("state")) - assert.NotEmpty(t, loginURL.Query().Get("nonce")) - assert.NotEmpty(t, loginURL.Query().Get("code_challenge")) - - resp = get(t, rpClient, loginURL.String()) - assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - - callbackURL := resp.Location - assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state")) - assert.NotEmpty(t, callbackURL.Query().Get("code")) -} diff --git a/pkg/handler/logout.go b/pkg/handler/logout.go deleted file mode 100644 index 7ee93ca..0000000 --- a/pkg/handler/logout.go +++ /dev/null @@ -1,65 +0,0 @@ -package handler - -import ( - "errors" - "fmt" - "net/http" - - "github.com/nais/wonderwall/pkg/cookie" - errorhandler "github.com/nais/wonderwall/pkg/handler/error" - "github.com/nais/wonderwall/pkg/metrics" - logentry "github.com/nais/wonderwall/pkg/middleware" - openidclient "github.com/nais/wonderwall/pkg/openid/client" - "github.com/nais/wonderwall/pkg/session" -) - -type LogoutSource interface { - GetClient() *openidclient.Client - GetCookieOptions() cookie.Options - GetCookieOptsPathAware(r *http.Request) cookie.Options - GetErrorHandler() errorhandler.Handler - GetSessions() *session.Handler -} - -type LogoutOptions struct { - GlobalLogout bool -} - -func Logout(src LogoutSource, w http.ResponseWriter, r *http.Request, opts LogoutOptions) { - logger := logentry.LogEntryFrom(r) - logout, err := src.GetClient().Logout(r) - if err != nil { - src.GetErrorHandler().InternalError(w, r, err) - return - } - - var idToken string - - sessions := src.GetSessions() - - ticket, err := sessions.GetTicket(r) - if err == nil { - sessionData, err := sessions.Get(r, ticket) - if err == nil && sessionData != nil { - idToken = sessionData.IDToken - - err = sessions.Destroy(r, ticket.Key()) - if err != nil && !errors.Is(err, session.ErrKeyNotFound) { - src.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) - return - } - - logger.WithField("jti", sessionData.IDTokenJwtID). - Info("logout: successful local logout") - metrics.ObserveLogout(metrics.LogoutOperationLocal) - } - } - - cookie.Clear(w, cookie.Session, src.GetCookieOptsPathAware(r)) - - if opts.GlobalLogout { - logger.Debug("logout: redirecting to identity provider for global/single-logout") - metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated) - http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect) - } -} diff --git a/pkg/handler/logout_callback.go b/pkg/handler/logout_callback.go deleted file mode 100644 index 5ac3a2b..0000000 --- a/pkg/handler/logout_callback.go +++ /dev/null @@ -1,22 +0,0 @@ -package handler - -import ( - "net/http" - - "github.com/nais/wonderwall/pkg/cookie" - logentry "github.com/nais/wonderwall/pkg/middleware" - openidclient "github.com/nais/wonderwall/pkg/openid/client" -) - -type LogoutCallbackSource interface { - GetClient() *openidclient.Client - GetCookieOptsPathAware(r *http.Request) cookie.Options -} - -func LogoutCallback(src LogoutCallbackSource, w http.ResponseWriter, r *http.Request) { - redirect := src.GetClient().LogoutCallback(r).PostLogoutRedirectURI() - - cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r)) - logentry.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect) - http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) -} diff --git a/pkg/handler/logout_callback_test.go b/pkg/handler/logout_callback_test.go deleted file mode 100644 index 38bac6d..0000000 --- a/pkg/handler/logout_callback_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package handler_test - -import ( - "testing" - - "github.com/nais/wonderwall/pkg/mock" -) - -func TestLogoutCallback(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - logout(t, rpClient, idp) -} diff --git a/pkg/handler/logout_frontchannel.go b/pkg/handler/logout_frontchannel.go deleted file mode 100644 index e1a105b..0000000 --- a/pkg/handler/logout_frontchannel.go +++ /dev/null @@ -1,61 +0,0 @@ -package handler - -import ( - "fmt" - "net/http" - - "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/metrics" - mw "github.com/nais/wonderwall/pkg/middleware" - openidclient "github.com/nais/wonderwall/pkg/openid/client" - "github.com/nais/wonderwall/pkg/session" -) - -type LogoutFrontChannelSource interface { - GetClient() *openidclient.Client - GetCookieOptions() cookie.Options - GetCookieOptsPathAware(r *http.Request) cookie.Options - GetSessions() *session.Handler -} - -func LogoutFrontChannel(src LogoutFrontChannelSource, w http.ResponseWriter, r *http.Request) { - logger := mw.LogEntryFrom(r) - - // Unconditionally destroy all local references to the session. - cookie.Clear(w, cookie.Session, src.GetCookieOptsPathAware(r)) - - sessions := src.GetSessions() - client := src.GetClient() - key, err := getSessionKey(r, sessions, client) - if err != nil { - logger.Debugf("front-channel logout: getting session key: %+v; ignoring", err) - w.WriteHeader(http.StatusAccepted) - return - } - - err = sessions.Destroy(r, key) - if err != nil { - logger.Warnf("front-channel logout: destroying session: %+v", err) - w.WriteHeader(http.StatusAccepted) - return - } - - cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r)) - metrics.ObserveLogout(metrics.LogoutOperationFrontChannel) - w.WriteHeader(http.StatusOK) -} - -func getSessionKey(r *http.Request, sessions *session.Handler, client *openidclient.Client) (string, error) { - logoutFrontchannel := client.LogoutFrontchannel(r) - - if logoutFrontchannel.MissingSidParameter() { - ticket, err := sessions.GetTicket(r) - if err != nil { - return ticket.Key(), nil - } - return "", fmt.Errorf("neither sid parameter nor session ticket found in request: %w", err) - } - - sid := logoutFrontchannel.Sid() - return sessions.Key(sid), nil -} diff --git a/pkg/handler/logout_frontchannel_test.go b/pkg/handler/logout_frontchannel_test.go deleted file mode 100644 index 4a7a287..0000000 --- a/pkg/handler/logout_frontchannel_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package handler_test - -import ( - "net/http" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/mock" - "github.com/nais/wonderwall/pkg/session" -) - -func TestFrontChannelLogout(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport() - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - sessionCookie := login(t, rpClient, idp) - - // Trigger front-channel logout - sid := func(r *http.Request) string { - r.AddCookie(sessionCookie) - - ticket, err := session.GetTicket(r, idp.RelyingPartyHandler.GetCrypter()) - assert.NoError(t, err) - - data, err := idp.RelyingPartyHandler.GetSessions().Get(r, ticket) - assert.NoError(t, err) - - return data.ExternalSessionID - } - - frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel") - assert.NoError(t, err) - - req := idp.GetRequest(frontchannelLogoutURL.String()) - - values := url.Values{} - values.Add("sid", sid(req)) - values.Add("iss", idp.OpenIDConfig.Provider().Issuer()) - frontchannelLogoutURL.RawQuery = values.Encode() - - resp := get(t, rpClient, frontchannelLogoutURL.String()) - assert.Equal(t, http.StatusOK, resp.StatusCode) -} diff --git a/pkg/handler/logout_test.go b/pkg/handler/logout_test.go deleted file mode 100644 index 8e0a444..0000000 --- a/pkg/handler/logout_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package handler_test - -import ( - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/mock" - urlpkg "github.com/nais/wonderwall/pkg/url" -) - -func TestLogout(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - - resp := selfInitiatedLogout(t, rpClient, idp) - - // Get endsession endpoint after local logout - endsessionURL := resp.Location - - idpserverURL, err := url.Parse(idp.ProviderServer.URL) - assert.NoError(t, err) - - req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback") - expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req) - assert.NoError(t, err) - - endsessionParams := endsessionURL.Query() - assert.Equal(t, idpserverURL.Host, endsessionURL.Host) - assert.Equal(t, "/endsession", endsessionURL.Path) - assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"]) - assert.NotEmpty(t, endsessionParams["id_token_hint"]) -} - -func TestLogoutLocal(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - - localLogout(t, rpClient, idp) -} diff --git a/pkg/handler/session.go b/pkg/handler/session.go deleted file mode 100644 index fabb838..0000000 --- a/pkg/handler/session.go +++ /dev/null @@ -1,55 +0,0 @@ -package handler - -import ( - "encoding/json" - "errors" - "net/http" - - "github.com/nais/wonderwall/pkg/config" - mw "github.com/nais/wonderwall/pkg/middleware" - "github.com/nais/wonderwall/pkg/session" -) - -type SessionSource interface { - GetSessions() *session.Handler - GetSessionConfig() config.Session -} - -func Session(src SessionSource, w http.ResponseWriter, r *http.Request) { - logger := mw.LogEntryFrom(r) - - ticket, err := src.GetSessions().GetTicket(r) - if err != nil { - logger.Infof("session/refresh: getting ticket: %+v", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - - data, err := src.GetSessions().Get(r, ticket) - if err != nil { - switch { - case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound): - logger.Infof("session/info: getting session: %+v", err) - w.WriteHeader(http.StatusUnauthorized) - return - default: - logger.Warnf("session/info: getting session: %+v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - } - - w.Header().Set("Content-Type", "application/json") - - if src.GetSessionConfig().Refresh { - err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh()) - } else { - err = json.NewEncoder(w).Encode(data.Metadata.Verbose()) - } - - if err != nil { - logger.Warnf("session/info: marshalling metadata: %+v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } -} diff --git a/pkg/handler/session_refresh.go b/pkg/handler/session_refresh.go deleted file mode 100644 index 82a3063..0000000 --- a/pkg/handler/session_refresh.go +++ /dev/null @@ -1,59 +0,0 @@ -package handler - -import ( - "encoding/json" - "errors" - "net/http" - - mw "github.com/nais/wonderwall/pkg/middleware" - "github.com/nais/wonderwall/pkg/session" -) - -type SessionRefreshSource interface { - GetSessions() *session.Handler -} - -func SessionRefresh(src SessionRefreshSource, w http.ResponseWriter, r *http.Request) { - logger := mw.LogEntryFrom(r) - - ticket, err := src.GetSessions().GetTicket(r) - if err != nil { - logger.Infof("session/refresh: getting ticket: %+v", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - - data, err := src.GetSessions().Get(r, ticket) - if err != nil { - switch { - case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound): - logger.Infof("session/refresh: getting session: %+v", err) - w.WriteHeader(http.StatusUnauthorized) - default: - logger.Warnf("session/refresh: getting session: %+v", err) - w.WriteHeader(http.StatusInternalServerError) - } - return - } - - data, err = src.GetSessions().Refresh(r, ticket, data) - if err != nil { - if errors.Is(err, session.ErrInvalidIdpState) || errors.Is(err, session.ErrInvalidSession) { - logger.Infof("session/refresh: refreshing: %+v", err) - w.WriteHeader(http.StatusUnauthorized) - return - } - - logger.Warnf("session/refresh: refreshing: %+v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } - - w.Header().Set("Content-Type", "application/json") - err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh()) - if err != nil { - logger.Warnf("session/refresh: marshalling metadata: %+v", err) - w.WriteHeader(http.StatusInternalServerError) - return - } -} diff --git a/pkg/handler/session_refresh_test.go b/pkg/handler/session_refresh_test.go deleted file mode 100644 index 4f1df88..0000000 --- a/pkg/handler/session_refresh_test.go +++ /dev/null @@ -1,150 +0,0 @@ -package handler_test - -import ( - "encoding/json" - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/mock" - "github.com/nais/wonderwall/pkg/session" -) - -func TestSessionRefresh(t *testing.T) { - cfg := mock.Config() - cfg.Session.Refresh = true - - idp := mock.NewIdentityProvider(cfg) - idp.ProviderHandler.TokenDuration = 5 * time.Second - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - - // get initial session info - resp := sessionInfo(t, idp, rpClient) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - var data session.MetadataVerboseWithRefresh - err := json.Unmarshal([]byte(resp.Body), &data) - assert.NoError(t, err) - - // wait until refresh cooldown has reached zero before refresh - waitForRefreshCooldownTimer(t, idp, rpClient) - - resp = sessionRefresh(t, idp, rpClient) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - var refreshedData session.MetadataVerboseWithRefresh - err = json.Unmarshal([]byte(resp.Body), &refreshedData) - assert.NoError(t, err) - - // session create and end times should be unchanged - assert.WithinDuration(t, data.Session.CreatedAt, refreshedData.Session.CreatedAt, 0) - assert.WithinDuration(t, data.Session.EndsAt, refreshedData.Session.EndsAt, 0) - - // token expiration and refresh times should be later than before - assert.True(t, refreshedData.Tokens.ExpireAt.After(data.Tokens.ExpireAt)) - assert.True(t, refreshedData.Tokens.RefreshedAt.After(data.Tokens.RefreshedAt)) - - allowedSkew := 5 * time.Second - assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.Tokens.ExpireAt, allowedSkew) - assert.WithinDuration(t, time.Now(), refreshedData.Tokens.RefreshedAt, allowedSkew) - - sessionEndDuration := time.Duration(refreshedData.Session.EndsInSeconds) * time.Second - // 1 second < time until session ends <= configured max session lifetime - assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime) - assert.Greater(t, sessionEndDuration, time.Second) - - tokenExpiryDuration := time.Duration(refreshedData.Tokens.ExpireInSeconds) * time.Second - // 1 second < time until token expires <= max duration for tokens from IDP - assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) - assert.Greater(t, tokenExpiryDuration, time.Second) - - // 1 second < next token refresh <= seconds until token expires - assert.LessOrEqual(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds) - assert.Greater(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(1)) - - assert.True(t, refreshedData.Tokens.RefreshCooldown) - // 1 second < refresh cooldown <= minimum refresh interval - assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval) - assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1)) - - assert.True(t, data.Session.Active) - assert.True(t, refreshedData.Session.Active) - - assert.True(t, data.Session.TimeoutAt.IsZero()) - assert.True(t, refreshedData.Session.TimeoutAt.IsZero()) - - assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) - assert.Equal(t, int64(-1), refreshedData.Session.TimeoutInSeconds) -} - -func TestSessionRefresh_Disabled(t *testing.T) { - cfg := mock.Config() - cfg.Session.Refresh = false - - idp := mock.NewIdentityProvider(cfg) - idp.ProviderHandler.TokenDuration = 5 * time.Second - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - - resp := sessionRefresh(t, idp, rpClient) - assert.Equal(t, http.StatusNotFound, resp.StatusCode) -} - -func TestSessionRefresh_WithInactivity(t *testing.T) { - cfg := mock.Config() - cfg.Session.Refresh = true - cfg.Session.Inactivity = true - cfg.Session.InactivityTimeout = 10 * time.Minute - - idp := mock.NewIdentityProvider(cfg) - idp.ProviderHandler.TokenDuration = 5 * time.Second - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - - // get initial session info - resp := sessionInfo(t, idp, rpClient) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - var data session.MetadataVerboseWithRefresh - err := json.Unmarshal([]byte(resp.Body), &data) - assert.NoError(t, err) - - // wait until refresh cooldown has reached zero before refresh - waitForRefreshCooldownTimer(t, idp, rpClient) - - resp = sessionRefresh(t, idp, rpClient) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - var refreshedData session.MetadataVerboseWithRefresh - err = json.Unmarshal([]byte(resp.Body), &refreshedData) - assert.NoError(t, err) - - maxDelta := 5 * time.Second - - assert.True(t, data.Session.Active) - assert.True(t, refreshedData.Session.Active) - - assert.False(t, data.Session.TimeoutAt.IsZero()) - assert.False(t, refreshedData.Session.TimeoutAt.IsZero()) - - expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout) - assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta) - assert.WithinDuration(t, expectedTimeoutAt, refreshedData.Session.TimeoutAt, maxDelta) - - assert.True(t, refreshedData.Session.TimeoutAt.After(data.Session.TimeoutAt)) - - previousTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second - assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(previousTimeoutDuration), maxDelta) - - refreshedTimeoutDuration := time.Duration(refreshedData.Session.TimeoutInSeconds) * time.Second - assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(refreshedTimeoutDuration), maxDelta) -} diff --git a/pkg/handler/session_test.go b/pkg/handler/session_test.go deleted file mode 100644 index 7c361dc..0000000 --- a/pkg/handler/session_test.go +++ /dev/null @@ -1,131 +0,0 @@ -package handler_test - -import ( - "encoding/json" - "net/http" - "testing" - "time" - - "github.com/stretchr/testify/assert" - - "github.com/nais/wonderwall/pkg/mock" - "github.com/nais/wonderwall/pkg/session" -) - -func TestSession(t *testing.T) { - cfg := mock.Config() - cfg.Session.Refresh = true - - idp := mock.NewIdentityProvider(cfg) - idp.ProviderHandler.TokenDuration = 5 * time.Minute - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - - resp := sessionInfo(t, idp, rpClient) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - var data session.MetadataVerbose - err := json.Unmarshal([]byte(resp.Body), &data) - assert.NoError(t, err) - - allowedSkew := 5 * time.Second - assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew) - assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew) - assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew) - assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew) - - sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second - // 1 second < time until session ends <= configured max session lifetime - assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime) - assert.Greater(t, sessionEndDuration, time.Second) - - tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second - // 1 second < time until token expires <= max duration for tokens from IDP - assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) - assert.Greater(t, tokenExpiryDuration, time.Second) - - assert.True(t, data.Session.Active) - assert.True(t, data.Session.TimeoutAt.IsZero()) - assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) -} - -func TestSession_WithInactivity(t *testing.T) { - cfg := mock.Config() - cfg.Session.Refresh = true - cfg.Session.Inactivity = true - cfg.Session.InactivityTimeout = 10 * time.Minute - - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - - resp := sessionInfo(t, idp, rpClient) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - var data session.MetadataVerbose - err := json.Unmarshal([]byte(resp.Body), &data) - assert.NoError(t, err) - - maxDelta := 5 * time.Second - - assert.True(t, data.Session.Active) - assert.False(t, data.Session.TimeoutAt.IsZero()) - - expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout) - assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta) - - actualTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second - assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(actualTimeoutDuration), maxDelta) -} - -func TestSession_WithRefresh(t *testing.T) { - cfg := mock.Config() - cfg.Session.Refresh = true - - idp := mock.NewIdentityProvider(cfg) - idp.ProviderHandler.TokenDuration = 5 * time.Minute - defer idp.Close() - - rpClient := idp.RelyingPartyClient() - login(t, rpClient, idp) - - resp := sessionInfo(t, idp, rpClient) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - var data session.MetadataVerboseWithRefresh - err := json.Unmarshal([]byte(resp.Body), &data) - assert.NoError(t, err) - - allowedSkew := 5 * time.Second - assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew) - assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew) - assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew) - assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew) - - sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second - // 1 second < time until session ends <= configured max session lifetime - assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime) - assert.Greater(t, sessionEndDuration, time.Second) - - tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second - // 1 second < time until token expires <= max duration for tokens from IDP - assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration) - assert.Greater(t, tokenExpiryDuration, time.Second) - - // 1 second < next token refresh <= seconds until token expires - assert.LessOrEqual(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds) - assert.Greater(t, data.Tokens.NextAutoRefreshInSeconds, int64(1)) - - assert.True(t, data.Tokens.RefreshCooldown) - // 1 second < refresh cooldown <= minimum refresh interval - assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval) - assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1)) - - assert.True(t, data.Session.Active) - assert.True(t, data.Session.TimeoutAt.IsZero()) - assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds) -} diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index d50f30e..a6e5f25 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -3,12 +3,15 @@ package client import ( "crypto/sha256" "encoding/base64" + "encoding/json" "errors" "fmt" "net/http" "golang.org/x/oauth2" + "github.com/nais/wonderwall/pkg/cookie" + "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/strings" @@ -52,11 +55,11 @@ func NewLogin(c *Client, r *http.Request) (*Login, error) { return nil, fmt.Errorf("generating auth code url: %w", err) } - cookie := params.cookie(callbackURL) + loginCookie := params.cookie(callbackURL) return &Login{ authCodeURL: url, - cookie: cookie, + cookie: loginCookie, params: params, }, nil } @@ -79,11 +82,6 @@ func (l *Login) CodeVerifier() string { return l.params.CodeVerifier } -func (l *Login) Cookie(canonicalRedirect string) *openid.LoginCookie { - l.cookie.Referer = canonicalRedirect - return l.cookie -} - func (l *Login) Nonce() string { return l.params.Nonce } @@ -92,6 +90,30 @@ func (l *Login) State() string { return l.params.State } +func (l *Login) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error { + l.cookie.Referer = canonicalRedirect + + loginCookieJson, err := json.Marshal(l.cookie) + if err != nil { + return fmt.Errorf("marshalling login cookie: %w", err) + } + + value := string(loginCookieJson) + + err = cookie.EncryptAndSet(w, cookie.Login, value, opts, crypter) + if err != nil { + return err + } + + // set a duplicate cookie without the SameSite value set for user agents that do not properly handle SameSite + err = cookie.EncryptAndSet(w, cookie.LoginLegacy, value, opts.WithSameSite(http.SameSiteDefaultMode), crypter) + if err != nil { + return err + } + + return nil +} + type loginParameters struct { *Client CodeVerifier string diff --git a/pkg/openid/client/login_test.go b/pkg/openid/client/login_test.go index 50f858b..551ee5f 100644 --- a/pkg/openid/client/login_test.go +++ b/pkg/openid/client/login_test.go @@ -103,6 +103,8 @@ func TestLogin_URL(t *testing.T) { assert.ElementsMatch(t, query["code_challenge"], []string{result.CodeChallenge()}) assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"}) + assert.Equal(t, client.CodeChallenge(result.CodeVerifier()), result.CodeChallenge()) + if test.extraParams != nil { for key, value := range test.extraParams { assert.Contains(t, query, key)