refactor: consolidate handlers

This commit is contained in:
Trong Huu Nguyen
2023-02-16 10:55:50 +01:00
parent 3274cc5c65
commit fb28da7241
19 changed files with 734 additions and 1000 deletions

View File

@@ -1,19 +1,29 @@
package handler
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
urllib "net/url"
"time"
"github.com/sethvargo/go-retry"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/handler/autologin"
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
"github.com/nais/wonderwall/pkg/ingress"
"github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/metrics"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
retrypkg "github.com/nais/wonderwall/pkg/retry"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/url"
@@ -34,6 +44,10 @@ type Standalone struct {
UpstreamProxy *ReverseProxy
}
type LogoutOptions struct {
GlobalLogout bool
}
func NewStandalone(
cfg *config.Config,
cookieOpts cookie.Options,
@@ -114,7 +128,7 @@ func (s *Standalone) GetIngresses() *ingress.Ingresses {
}
func (s *Standalone) GetPath(r *http.Request) string {
path, ok := middleware.PathFrom(r.Context())
path, ok := mw.PathFrom(r.Context())
if !ok {
path = s.Ingresses.MatchingPath(r)
}
@@ -135,37 +149,245 @@ func (s *Standalone) GetSessionConfig() config.Session {
}
func (s *Standalone) Login(w http.ResponseWriter, r *http.Request) {
Login(s, w, r)
canonicalRedirect := s.GetRedirect().Canonical(r)
login, err := s.GetClient().Login(r)
if err != nil {
if errors.Is(err, openidclient.ErrInvalidSecurityLevel) || errors.Is(err, openidclient.ErrInvalidLocale) {
s.GetErrorHandler().BadRequest(w, r, err)
} else {
s.GetErrorHandler().InternalError(w, r, err)
}
return
}
opts := s.GetCookieOptsPathAware(r).
WithExpiresIn(1 * time.Hour).
WithSameSite(http.SameSiteNoneMode)
err = login.SetCookie(w, opts, s.GetCrypter(), canonicalRedirect)
if err != nil {
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
return
}
fields := log.Fields{
"redirect_after_login": canonicalRedirect,
}
mw.LogEntryFrom(r).WithFields(fields).Info("login: redirecting to identity provider")
http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect)
}
func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) {
LoginCallback(s, w, r)
opts := s.GetCookieOptsPathAware(r)
// unconditionally clear login cookies
cookie.Clear(w, cookie.Login, opts.WithSameSite(http.SameSiteNoneMode))
cookie.Clear(w, cookie.LoginLegacy, opts.WithSameSite(http.SameSiteDefaultMode))
loginCookie, err := openid.GetLoginCookie(r, s.GetCrypter())
if err != nil {
msg := "callback: fetching login cookie"
if errors.Is(err, http.ErrNoCookie) {
msg += ": fallback cookie not found (user might have blocked all cookies, or the callback route was accessed before the login route)"
}
s.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("%s: %w", msg, err))
return
}
loginCallback, err := s.GetClient().LoginCallback(r, loginCookie)
if err != nil {
s.GetErrorHandler().InternalError(w, r, err)
return
}
if err := loginCallback.IdentityProviderError(); err != nil {
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: %w", err))
return
}
if err := loginCallback.StateMismatchError(); err != nil {
s.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("callback: %w", err))
return
}
var tokens *openid.Tokens
err = retry.Do(r.Context(), retrypkg.DefaultBackoff, func(ctx context.Context) error {
tokens, err = loginCallback.RedeemTokens(ctx)
return retry.RetryableError(err)
})
if err != nil {
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err))
return
}
sessionLifetime := s.GetSessionConfig().MaxLifetime
ticket, err := s.GetSessions().Create(r, tokens, sessionLifetime)
if err != nil {
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
return
}
err = ticket.Set(w, opts.WithExpiresIn(sessionLifetime), s.GetCrypter())
if err != nil {
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err))
return
}
redirect := s.GetRedirect().Clean(r, loginCookie.Referer)
fields := log.Fields{
"redirect_to": redirect,
"jti": tokens.IDToken.GetJwtID(),
}
mw.LogEntryFrom(r).WithFields(fields).Info("callback: successful login")
metrics.ObserveLogin()
cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r))
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}
func (s *Standalone) Logout(w http.ResponseWriter, r *http.Request) {
opts := LogoutOptions{
GlobalLogout: true,
}
Logout(s, w, r, opts)
s.logout(w, r, opts)
}
func (s *Standalone) LogoutLocal(w http.ResponseWriter, r *http.Request) {
opts := LogoutOptions{
GlobalLogout: false,
}
Logout(s, w, r, opts)
s.logout(w, r, opts)
}
func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, opts LogoutOptions) {
logger := mw.LogEntryFrom(r)
logout, err := s.GetClient().Logout(r)
if err != nil {
s.GetErrorHandler().InternalError(w, r, err)
return
}
var idToken string
sessions := s.GetSessions()
ticket, err := sessions.GetTicket(r)
if err == nil {
sessionData, err := sessions.Get(r, ticket)
if err == nil && sessionData != nil {
idToken = sessionData.IDToken
err = sessions.Destroy(r, ticket.Key())
if err != nil && !errors.Is(err, session.ErrKeyNotFound) {
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
return
}
logger.WithField("jti", sessionData.IDTokenJwtID).
Info("logout: successful local logout")
metrics.ObserveLogout(metrics.LogoutOperationLocal)
}
}
cookie.Clear(w, cookie.Session, s.GetCookieOptsPathAware(r))
if opts.GlobalLogout {
logger.Debug("logout: redirecting to identity provider for global/single-logout")
metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated)
http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect)
}
}
func (s *Standalone) LogoutCallback(w http.ResponseWriter, r *http.Request) {
LogoutCallback(s, w, r)
redirect := s.GetClient().LogoutCallback(r).PostLogoutRedirectURI()
cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r))
mw.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect)
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}
func (s *Standalone) LogoutFrontChannel(w http.ResponseWriter, r *http.Request) {
LogoutFrontChannel(s, w, r)
logger := mw.LogEntryFrom(r)
// Unconditionally destroy all local references to the session.
cookie.Clear(w, cookie.Session, s.GetCookieOptsPathAware(r))
sessions := s.GetSessions()
client := s.GetClient()
getSessionKey := func(r *http.Request) (string, error) {
lfc := client.LogoutFrontchannel(r)
if lfc.MissingSidParameter() {
ticket, err := sessions.GetTicket(r)
if err != nil {
return ticket.Key(), nil
}
return "", fmt.Errorf("neither sid parameter nor session ticket found in request: %w", err)
}
sid := lfc.Sid()
return sessions.Key(sid), nil
}
key, err := getSessionKey(r)
if err != nil {
logger.Debugf("front-channel logout: getting session key: %+v; ignoring", err)
w.WriteHeader(http.StatusAccepted)
return
}
err = sessions.Destroy(r, key)
if err != nil {
logger.Warnf("front-channel logout: destroying session: %+v", err)
w.WriteHeader(http.StatusAccepted)
return
}
cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r))
metrics.ObserveLogout(metrics.LogoutOperationFrontChannel)
w.WriteHeader(http.StatusOK)
}
func (s *Standalone) Session(w http.ResponseWriter, r *http.Request) {
Session(s, w, r)
logger := mw.LogEntryFrom(r)
ticket, err := s.GetSessions().GetTicket(r)
if err != nil {
logger.Infof("session/refresh: getting ticket: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
data, err := s.GetSessions().Get(r, ticket)
if err != nil {
switch {
case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound):
logger.Infof("session/info: getting session: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
default:
logger.Warnf("session/info: getting session: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
w.Header().Set("Content-Type", "application/json")
if s.GetSessionConfig().Refresh {
err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh())
} else {
err = json.NewEncoder(w).Encode(data.Metadata.Verbose())
}
if err != nil {
logger.Warnf("session/info: marshalling metadata: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func (s *Standalone) SessionRefresh(w http.ResponseWriter, r *http.Request) {
@@ -174,7 +396,48 @@ func (s *Standalone) SessionRefresh(w http.ResponseWriter, r *http.Request) {
return
}
SessionRefresh(s, w, r)
logger := mw.LogEntryFrom(r)
ticket, err := s.GetSessions().GetTicket(r)
if err != nil {
logger.Infof("session/refresh: getting ticket: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
data, err := s.GetSessions().Get(r, ticket)
if err != nil {
switch {
case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound):
logger.Infof("session/refresh: getting session: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
default:
logger.Warnf("session/refresh: getting session: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
}
return
}
data, err = s.GetSessions().Refresh(r, ticket, data)
if err != nil {
if errors.Is(err, session.ErrInvalidIdpState) || errors.Is(err, session.ErrInvalidSession) {
logger.Infof("session/refresh: refreshing: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
logger.Warnf("session/refresh: refreshing: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh())
if err != nil {
logger.Warnf("session/refresh: marshalling metadata: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
func (s *Standalone) ReverseProxy(w http.ResponseWriter, r *http.Request) {

View File

@@ -5,8 +5,11 @@ import (
"net/http"
urllib "net/url"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/ingress"
logentry "github.com/nais/wonderwall/pkg/middleware"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/router/paths"
@@ -56,7 +59,31 @@ func NewSSOProxy(cfg *config.Config) (*SSOProxy, error) {
}
func (s *SSOProxy) Login(w http.ResponseWriter, r *http.Request) {
LoginSSOProxy(s, w, r)
logger := logentry.LogEntryFrom(r)
target := s.GetSSOServerURL()
targetQuery := target.Query()
// override default query parameters
reqQuery := r.URL.Query()
if reqQuery.Has(openidclient.SecurityLevelURLParameter) {
targetQuery.Set(openidclient.SecurityLevelURLParameter, reqQuery.Get(openidclient.SecurityLevelURLParameter))
}
if reqQuery.Has(openidclient.LocaleURLParameter) {
targetQuery.Set(openidclient.LocaleURLParameter, reqQuery.Get(openidclient.LocaleURLParameter))
}
target.RawQuery = reqQuery.Encode()
canonicalRedirect := s.GetRedirect().Canonical(r)
ssoServerLoginURL := url.Login(target, canonicalRedirect)
logger.WithFields(log.Fields{
"redirect_to": ssoServerLoginURL,
"redirect_after_login": canonicalRedirect,
}).Info("login: redirecting to sso server")
http.Redirect(w, r, ssoServerLoginURL, http.StatusTemporaryRedirect)
}
func (s *SSOProxy) LoginCallback(w http.ResponseWriter, r *http.Request) {

View File

@@ -3,6 +3,7 @@ package handler_test
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -20,6 +21,407 @@ import (
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestLogin(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
resp := localLogin(t, rpClient, idp)
loginURL := resp.Location
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login")
expectedCallbackURL, err := urlpkg.LoginCallback(req)
assert.NoError(t, err)
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))
assert.Equal(t, "/authorize", loginURL.Path)
assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), loginURL.Query().Get("acr_values"))
assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), loginURL.Query().Get("ui_locales"))
assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), loginURL.Query().Get("client_id"))
assert.Equal(t, expectedCallbackURL, loginURL.Query().Get("redirect_uri"))
assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method"))
assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(loginURL.Query().Get("scope"), " "))
assert.NotEmpty(t, loginURL.Query().Get("state"))
assert.NotEmpty(t, loginURL.Query().Get("nonce"))
assert.NotEmpty(t, loginURL.Query().Get("code_challenge"))
resp = get(t, rpClient, loginURL.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
callbackURL := resp.Location
assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state"))
assert.NotEmpty(t, callbackURL.Query().Get("code"))
}
func TestCallback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
}
func TestCallback_SessionStateRequired(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession")
defer idp.Close()
rpClient := idp.RelyingPartyClient()
resp := authorize(t, rpClient, idp)
// Get callback URL after successful auth
params := resp.Location.Query()
sessionState := params.Get("session_state")
assert.NotEmpty(t, sessionState)
callback(t, rpClient, resp)
}
func TestLogout(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := selfInitiatedLogout(t, rpClient, idp)
// Get endsession endpoint after local logout
endsessionURL := resp.Location
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
assert.NoError(t, err)
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback")
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
assert.NoError(t, err)
endsessionParams := endsessionURL.Query()
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
assert.Equal(t, "/endsession", endsessionURL.Path)
assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"])
assert.NotEmpty(t, endsessionParams["id_token_hint"])
}
func TestLogoutLocal(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
localLogout(t, rpClient, idp)
}
func TestLogoutCallback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
logout(t, rpClient, idp)
}
func TestFrontChannelLogout(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport()
defer idp.Close()
rpClient := idp.RelyingPartyClient()
sessionCookie := login(t, rpClient, idp)
// Trigger front-channel logout
sid := func(r *http.Request) string {
r.AddCookie(sessionCookie)
ticket, err := session.GetTicket(r, idp.RelyingPartyHandler.GetCrypter())
assert.NoError(t, err)
data, err := idp.RelyingPartyHandler.GetSessions().Get(r, ticket)
assert.NoError(t, err)
return data.ExternalSessionID
}
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
assert.NoError(t, err)
req := idp.GetRequest(frontchannelLogoutURL.String())
values := url.Values{}
values.Add("sid", sid(req))
values.Add("iss", idp.OpenIDConfig.Provider().Issuer())
frontchannelLogoutURL.RawQuery = values.Encode()
resp := get(t, rpClient, frontchannelLogoutURL.String())
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestSessionRefresh(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
// get initial session info
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerboseWithRefresh
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
// wait until refresh cooldown has reached zero before refresh
waitForRefreshCooldownTimer(t, idp, rpClient)
resp = sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var refreshedData session.MetadataVerboseWithRefresh
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
assert.NoError(t, err)
// session create and end times should be unchanged
assert.WithinDuration(t, data.Session.CreatedAt, refreshedData.Session.CreatedAt, 0)
assert.WithinDuration(t, data.Session.EndsAt, refreshedData.Session.EndsAt, 0)
// token expiration and refresh times should be later than before
assert.True(t, refreshedData.Tokens.ExpireAt.After(data.Tokens.ExpireAt))
assert.True(t, refreshedData.Tokens.RefreshedAt.After(data.Tokens.RefreshedAt))
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), refreshedData.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(refreshedData.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(refreshedData.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
// 1 second < next token refresh <= seconds until token expires
assert.LessOrEqual(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds)
assert.Greater(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(1))
assert.True(t, refreshedData.Tokens.RefreshCooldown)
// 1 second < refresh cooldown <= minimum refresh interval
assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1))
assert.True(t, data.Session.Active)
assert.True(t, refreshedData.Session.Active)
assert.True(t, data.Session.TimeoutAt.IsZero())
assert.True(t, refreshedData.Session.TimeoutAt.IsZero())
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
assert.Equal(t, int64(-1), refreshedData.Session.TimeoutInSeconds)
}
func TestSessionRefresh_Disabled(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = false
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
func TestSessionRefresh_WithInactivity(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
cfg.Session.Inactivity = true
cfg.Session.InactivityTimeout = 10 * time.Minute
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
// get initial session info
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerboseWithRefresh
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
// wait until refresh cooldown has reached zero before refresh
waitForRefreshCooldownTimer(t, idp, rpClient)
resp = sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var refreshedData session.MetadataVerboseWithRefresh
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
assert.NoError(t, err)
maxDelta := 5 * time.Second
assert.True(t, data.Session.Active)
assert.True(t, refreshedData.Session.Active)
assert.False(t, data.Session.TimeoutAt.IsZero())
assert.False(t, refreshedData.Session.TimeoutAt.IsZero())
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
assert.WithinDuration(t, expectedTimeoutAt, refreshedData.Session.TimeoutAt, maxDelta)
assert.True(t, refreshedData.Session.TimeoutAt.After(data.Session.TimeoutAt))
previousTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(previousTimeoutDuration), maxDelta)
refreshedTimeoutDuration := time.Duration(refreshedData.Session.TimeoutInSeconds) * time.Second
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(refreshedTimeoutDuration), maxDelta)
}
func TestSession(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Minute
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerbose
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
assert.True(t, data.Session.Active)
assert.True(t, data.Session.TimeoutAt.IsZero())
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
}
func TestSession_WithInactivity(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
cfg.Session.Inactivity = true
cfg.Session.InactivityTimeout = 10 * time.Minute
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerbose
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
maxDelta := 5 * time.Second
assert.True(t, data.Session.Active)
assert.False(t, data.Session.TimeoutAt.IsZero())
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
actualTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(actualTimeoutDuration), maxDelta)
}
func TestSession_WithRefresh(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Minute
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerboseWithRefresh
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
// 1 second < next token refresh <= seconds until token expires
assert.LessOrEqual(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds)
assert.Greater(t, data.Tokens.NextAutoRefreshInSeconds, int64(1))
assert.True(t, data.Tokens.RefreshCooldown)
// 1 second < refresh cooldown <= minimum refresh interval
assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1))
assert.True(t, data.Session.Active)
assert.True(t, data.Session.TimeoutAt.IsZero())
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
}
func localLogin(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
// First, run /oauth2/login to set cookies
loginURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/login")

View File

@@ -1,116 +0,0 @@
package handler
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"time"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
const (
CookieLifetime = 1 * time.Hour
)
type LoginSource interface {
GetClient() *openidclient.Client
GetCookieOptsPathAware(r *http.Request) cookie.Options
GetCrypter() crypto.Crypter
GetErrorHandler() errorhandler.Handler
GetRedirect() urlpkg.Redirect
}
func Login(src LoginSource, w http.ResponseWriter, r *http.Request) {
canonicalRedirect := src.GetRedirect().Canonical(r)
login, err := src.GetClient().Login(r)
if err != nil {
if errors.Is(err, openidclient.ErrInvalidSecurityLevel) || errors.Is(err, openidclient.ErrInvalidLocale) {
src.GetErrorHandler().BadRequest(w, r, err)
} else {
src.GetErrorHandler().InternalError(w, r, err)
}
return
}
err = setLoginCookies(src, w, r, login.Cookie(canonicalRedirect))
if err != nil {
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
return
}
fields := log.Fields{
"redirect_after_login": canonicalRedirect,
}
logentry.LogEntryFrom(r).WithFields(fields).Info("login: redirecting to identity provider")
http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect)
}
type LoginSSOProxySource interface {
GetSSOServerURL() *url.URL
GetRedirect() urlpkg.Redirect
}
func LoginSSOProxy(src LoginSSOProxySource, w http.ResponseWriter, r *http.Request) {
logger := logentry.LogEntryFrom(r)
target := src.GetSSOServerURL()
targetQuery := target.Query()
// override default query parameters
reqQuery := r.URL.Query()
if reqQuery.Has(openidclient.SecurityLevelURLParameter) {
targetQuery.Set(openidclient.SecurityLevelURLParameter, reqQuery.Get(openidclient.SecurityLevelURLParameter))
}
if reqQuery.Has(openidclient.LocaleURLParameter) {
targetQuery.Set(openidclient.LocaleURLParameter, reqQuery.Get(openidclient.LocaleURLParameter))
}
target.RawQuery = reqQuery.Encode()
canonicalRedirect := src.GetRedirect().Canonical(r)
ssoServerLoginURL := urlpkg.Login(target, canonicalRedirect)
logger.WithFields(log.Fields{
"redirect_to": ssoServerLoginURL,
"redirect_after_login": canonicalRedirect,
}).Info("login: redirecting to sso server")
http.Redirect(w, r, ssoServerLoginURL, http.StatusTemporaryRedirect)
}
func setLoginCookies(src LoginSource, w http.ResponseWriter, r *http.Request, loginCookie *openid.LoginCookie) error {
loginCookieJson, err := json.Marshal(loginCookie)
if err != nil {
return fmt.Errorf("marshalling login cookie: %w", err)
}
opts := src.GetCookieOptsPathAware(r).
WithExpiresIn(CookieLifetime).
WithSameSite(http.SameSiteNoneMode)
value := string(loginCookieJson)
err = cookie.EncryptAndSet(w, cookie.Login, value, opts, src.GetCrypter())
if err != nil {
return err
}
// set a duplicate cookie without the SameSite value set for user agents that do not properly handle SameSite
err = cookie.EncryptAndSet(w, cookie.LoginLegacy, value, opts.WithSameSite(http.SameSiteDefaultMode), src.GetCrypter())
if err != nil {
return err
}
return nil
}

View File

@@ -1,125 +0,0 @@
package handler
import (
"context"
"errors"
"fmt"
"net/http"
"github.com/sethvargo/go-retry"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
"github.com/nais/wonderwall/pkg/metrics"
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
retrypkg "github.com/nais/wonderwall/pkg/retry"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/url"
)
type LoginCallbackSource interface {
GetClient() *openidclient.Client
GetCookieOptions() cookie.Options
GetCookieOptsPathAware(r *http.Request) cookie.Options
GetCrypter() crypto.Crypter
GetErrorHandler() errorhandler.Handler
GetRedirect() url.Redirect
GetSessions() *session.Handler
GetSessionConfig() config.Session
}
func LoginCallback(src LoginCallbackSource, w http.ResponseWriter, r *http.Request) {
// unconditionally clear login cookie
clearLoginCookies(src, w, r)
loginCookie, err := openid.GetLoginCookie(r, src.GetCrypter())
if err != nil {
msg := "callback: fetching login cookie"
if errors.Is(err, http.ErrNoCookie) {
msg += ": fallback cookie not found (user might have blocked all cookies, or the callback route was accessed before the login route)"
}
src.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("%s: %w", msg, err))
return
}
loginCallback, err := src.GetClient().LoginCallback(r, loginCookie)
if err != nil {
src.GetErrorHandler().InternalError(w, r, err)
return
}
if err := loginCallback.IdentityProviderError(); err != nil {
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: %w", err))
return
}
if err := loginCallback.StateMismatchError(); err != nil {
src.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("callback: %w", err))
return
}
tokens, err := redeemValidTokens(r, loginCallback)
if err != nil {
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err))
return
}
sessionLifetime := src.GetSessionConfig().MaxLifetime
ticket, err := src.GetSessions().Create(r, tokens, sessionLifetime)
if err != nil {
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
return
}
opts := src.GetCookieOptsPathAware(r).
WithExpiresIn(sessionLifetime)
err = ticket.Set(w, opts, src.GetCrypter())
if err != nil {
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err))
return
}
redirect := src.GetRedirect().Clean(r, loginCookie.Referer)
logSuccessfulLogin(r, tokens, redirect)
cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r))
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}
func clearLoginCookies(src LogoutCallbackSource, w http.ResponseWriter, r *http.Request) {
opts := src.GetCookieOptsPathAware(r)
cookie.Clear(w, cookie.Login, opts.WithSameSite(http.SameSiteNoneMode))
cookie.Clear(w, cookie.LoginLegacy, opts.WithSameSite(http.SameSiteDefaultMode))
}
func redeemValidTokens(r *http.Request, loginCallback *openidclient.LoginCallback) (*openid.Tokens, error) {
var tokens *openid.Tokens
var err error
retryable := func(ctx context.Context) error {
tokens, err = loginCallback.RedeemTokens(ctx)
return retry.RetryableError(err)
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return nil, err
}
return tokens, nil
}
func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string) {
fields := log.Fields{
"redirect_to": referer,
"jti": tokens.IDToken.GetJwtID(),
}
logentry.LogEntryFrom(r).WithFields(fields).Info("callback: successful login")
metrics.ObserveLogin()
}

View File

@@ -1,36 +0,0 @@
package handler_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
)
func TestCallback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
}
func TestCallback_SessionStateRequired(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession")
defer idp.Close()
rpClient := idp.RelyingPartyClient()
resp := authorize(t, rpClient, idp)
// Get callback URL after successful auth
params := resp.Location.Query()
sessionState := params.Get("session_state")
assert.NotEmpty(t, sessionState)
callback(t, rpClient, resp)
}

View File

@@ -1,48 +0,0 @@
package handler_test
import (
"fmt"
"net/http"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestLogin(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
resp := localLogin(t, rpClient, idp)
loginURL := resp.Location
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login")
expectedCallbackURL, err := urlpkg.LoginCallback(req)
assert.NoError(t, err)
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))
assert.Equal(t, "/authorize", loginURL.Path)
assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), loginURL.Query().Get("acr_values"))
assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), loginURL.Query().Get("ui_locales"))
assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), loginURL.Query().Get("client_id"))
assert.Equal(t, expectedCallbackURL, loginURL.Query().Get("redirect_uri"))
assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method"))
assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(loginURL.Query().Get("scope"), " "))
assert.NotEmpty(t, loginURL.Query().Get("state"))
assert.NotEmpty(t, loginURL.Query().Get("nonce"))
assert.NotEmpty(t, loginURL.Query().Get("code_challenge"))
resp = get(t, rpClient, loginURL.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
callbackURL := resp.Location
assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state"))
assert.NotEmpty(t, callbackURL.Query().Get("code"))
}

View File

@@ -1,65 +0,0 @@
package handler
import (
"errors"
"fmt"
"net/http"
"github.com/nais/wonderwall/pkg/cookie"
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
"github.com/nais/wonderwall/pkg/metrics"
logentry "github.com/nais/wonderwall/pkg/middleware"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
"github.com/nais/wonderwall/pkg/session"
)
type LogoutSource interface {
GetClient() *openidclient.Client
GetCookieOptions() cookie.Options
GetCookieOptsPathAware(r *http.Request) cookie.Options
GetErrorHandler() errorhandler.Handler
GetSessions() *session.Handler
}
type LogoutOptions struct {
GlobalLogout bool
}
func Logout(src LogoutSource, w http.ResponseWriter, r *http.Request, opts LogoutOptions) {
logger := logentry.LogEntryFrom(r)
logout, err := src.GetClient().Logout(r)
if err != nil {
src.GetErrorHandler().InternalError(w, r, err)
return
}
var idToken string
sessions := src.GetSessions()
ticket, err := sessions.GetTicket(r)
if err == nil {
sessionData, err := sessions.Get(r, ticket)
if err == nil && sessionData != nil {
idToken = sessionData.IDToken
err = sessions.Destroy(r, ticket.Key())
if err != nil && !errors.Is(err, session.ErrKeyNotFound) {
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
return
}
logger.WithField("jti", sessionData.IDTokenJwtID).
Info("logout: successful local logout")
metrics.ObserveLogout(metrics.LogoutOperationLocal)
}
}
cookie.Clear(w, cookie.Session, src.GetCookieOptsPathAware(r))
if opts.GlobalLogout {
logger.Debug("logout: redirecting to identity provider for global/single-logout")
metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated)
http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect)
}
}

View File

@@ -1,22 +0,0 @@
package handler
import (
"net/http"
"github.com/nais/wonderwall/pkg/cookie"
logentry "github.com/nais/wonderwall/pkg/middleware"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
)
type LogoutCallbackSource interface {
GetClient() *openidclient.Client
GetCookieOptsPathAware(r *http.Request) cookie.Options
}
func LogoutCallback(src LogoutCallbackSource, w http.ResponseWriter, r *http.Request) {
redirect := src.GetClient().LogoutCallback(r).PostLogoutRedirectURI()
cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r))
logentry.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect)
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
}

View File

@@ -1,17 +0,0 @@
package handler_test
import (
"testing"
"github.com/nais/wonderwall/pkg/mock"
)
func TestLogoutCallback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
logout(t, rpClient, idp)
}

View File

@@ -1,61 +0,0 @@
package handler
import (
"fmt"
"net/http"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/metrics"
mw "github.com/nais/wonderwall/pkg/middleware"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
"github.com/nais/wonderwall/pkg/session"
)
type LogoutFrontChannelSource interface {
GetClient() *openidclient.Client
GetCookieOptions() cookie.Options
GetCookieOptsPathAware(r *http.Request) cookie.Options
GetSessions() *session.Handler
}
func LogoutFrontChannel(src LogoutFrontChannelSource, w http.ResponseWriter, r *http.Request) {
logger := mw.LogEntryFrom(r)
// Unconditionally destroy all local references to the session.
cookie.Clear(w, cookie.Session, src.GetCookieOptsPathAware(r))
sessions := src.GetSessions()
client := src.GetClient()
key, err := getSessionKey(r, sessions, client)
if err != nil {
logger.Debugf("front-channel logout: getting session key: %+v; ignoring", err)
w.WriteHeader(http.StatusAccepted)
return
}
err = sessions.Destroy(r, key)
if err != nil {
logger.Warnf("front-channel logout: destroying session: %+v", err)
w.WriteHeader(http.StatusAccepted)
return
}
cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r))
metrics.ObserveLogout(metrics.LogoutOperationFrontChannel)
w.WriteHeader(http.StatusOK)
}
func getSessionKey(r *http.Request, sessions *session.Handler, client *openidclient.Client) (string, error) {
logoutFrontchannel := client.LogoutFrontchannel(r)
if logoutFrontchannel.MissingSidParameter() {
ticket, err := sessions.GetTicket(r)
if err != nil {
return ticket.Key(), nil
}
return "", fmt.Errorf("neither sid parameter nor session ticket found in request: %w", err)
}
sid := logoutFrontchannel.Sid()
return sessions.Key(sid), nil
}

View File

@@ -1,48 +0,0 @@
package handler_test
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/session"
)
func TestFrontChannelLogout(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport()
defer idp.Close()
rpClient := idp.RelyingPartyClient()
sessionCookie := login(t, rpClient, idp)
// Trigger front-channel logout
sid := func(r *http.Request) string {
r.AddCookie(sessionCookie)
ticket, err := session.GetTicket(r, idp.RelyingPartyHandler.GetCrypter())
assert.NoError(t, err)
data, err := idp.RelyingPartyHandler.GetSessions().Get(r, ticket)
assert.NoError(t, err)
return data.ExternalSessionID
}
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
assert.NoError(t, err)
req := idp.GetRequest(frontchannelLogoutURL.String())
values := url.Values{}
values.Add("sid", sid(req))
values.Add("iss", idp.OpenIDConfig.Provider().Issuer())
frontchannelLogoutURL.RawQuery = values.Encode()
resp := get(t, rpClient, frontchannelLogoutURL.String())
assert.Equal(t, http.StatusOK, resp.StatusCode)
}

View File

@@ -1,49 +0,0 @@
package handler_test
import (
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
func TestLogout(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := selfInitiatedLogout(t, rpClient, idp)
// Get endsession endpoint after local logout
endsessionURL := resp.Location
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
assert.NoError(t, err)
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback")
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
assert.NoError(t, err)
endsessionParams := endsessionURL.Query()
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
assert.Equal(t, "/endsession", endsessionURL.Path)
assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"])
assert.NotEmpty(t, endsessionParams["id_token_hint"])
}
func TestLogoutLocal(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
localLogout(t, rpClient, idp)
}

View File

@@ -1,55 +0,0 @@
package handler
import (
"encoding/json"
"errors"
"net/http"
"github.com/nais/wonderwall/pkg/config"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
)
type SessionSource interface {
GetSessions() *session.Handler
GetSessionConfig() config.Session
}
func Session(src SessionSource, w http.ResponseWriter, r *http.Request) {
logger := mw.LogEntryFrom(r)
ticket, err := src.GetSessions().GetTicket(r)
if err != nil {
logger.Infof("session/refresh: getting ticket: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
data, err := src.GetSessions().Get(r, ticket)
if err != nil {
switch {
case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound):
logger.Infof("session/info: getting session: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
default:
logger.Warnf("session/info: getting session: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}
w.Header().Set("Content-Type", "application/json")
if src.GetSessionConfig().Refresh {
err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh())
} else {
err = json.NewEncoder(w).Encode(data.Metadata.Verbose())
}
if err != nil {
logger.Warnf("session/info: marshalling metadata: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

View File

@@ -1,59 +0,0 @@
package handler
import (
"encoding/json"
"errors"
"net/http"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
)
type SessionRefreshSource interface {
GetSessions() *session.Handler
}
func SessionRefresh(src SessionRefreshSource, w http.ResponseWriter, r *http.Request) {
logger := mw.LogEntryFrom(r)
ticket, err := src.GetSessions().GetTicket(r)
if err != nil {
logger.Infof("session/refresh: getting ticket: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
data, err := src.GetSessions().Get(r, ticket)
if err != nil {
switch {
case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound):
logger.Infof("session/refresh: getting session: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
default:
logger.Warnf("session/refresh: getting session: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
}
return
}
data, err = src.GetSessions().Refresh(r, ticket, data)
if err != nil {
if errors.Is(err, session.ErrInvalidIdpState) || errors.Is(err, session.ErrInvalidSession) {
logger.Infof("session/refresh: refreshing: %+v", err)
w.WriteHeader(http.StatusUnauthorized)
return
}
logger.Warnf("session/refresh: refreshing: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh())
if err != nil {
logger.Warnf("session/refresh: marshalling metadata: %+v", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
}

View File

@@ -1,150 +0,0 @@
package handler_test
import (
"encoding/json"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/session"
)
func TestSessionRefresh(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
// get initial session info
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerboseWithRefresh
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
// wait until refresh cooldown has reached zero before refresh
waitForRefreshCooldownTimer(t, idp, rpClient)
resp = sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var refreshedData session.MetadataVerboseWithRefresh
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
assert.NoError(t, err)
// session create and end times should be unchanged
assert.WithinDuration(t, data.Session.CreatedAt, refreshedData.Session.CreatedAt, 0)
assert.WithinDuration(t, data.Session.EndsAt, refreshedData.Session.EndsAt, 0)
// token expiration and refresh times should be later than before
assert.True(t, refreshedData.Tokens.ExpireAt.After(data.Tokens.ExpireAt))
assert.True(t, refreshedData.Tokens.RefreshedAt.After(data.Tokens.RefreshedAt))
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), refreshedData.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(refreshedData.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(refreshedData.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
// 1 second < next token refresh <= seconds until token expires
assert.LessOrEqual(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds)
assert.Greater(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(1))
assert.True(t, refreshedData.Tokens.RefreshCooldown)
// 1 second < refresh cooldown <= minimum refresh interval
assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1))
assert.True(t, data.Session.Active)
assert.True(t, refreshedData.Session.Active)
assert.True(t, data.Session.TimeoutAt.IsZero())
assert.True(t, refreshedData.Session.TimeoutAt.IsZero())
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
assert.Equal(t, int64(-1), refreshedData.Session.TimeoutInSeconds)
}
func TestSessionRefresh_Disabled(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = false
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
func TestSessionRefresh_WithInactivity(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
cfg.Session.Inactivity = true
cfg.Session.InactivityTimeout = 10 * time.Minute
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
// get initial session info
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerboseWithRefresh
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
// wait until refresh cooldown has reached zero before refresh
waitForRefreshCooldownTimer(t, idp, rpClient)
resp = sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var refreshedData session.MetadataVerboseWithRefresh
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
assert.NoError(t, err)
maxDelta := 5 * time.Second
assert.True(t, data.Session.Active)
assert.True(t, refreshedData.Session.Active)
assert.False(t, data.Session.TimeoutAt.IsZero())
assert.False(t, refreshedData.Session.TimeoutAt.IsZero())
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
assert.WithinDuration(t, expectedTimeoutAt, refreshedData.Session.TimeoutAt, maxDelta)
assert.True(t, refreshedData.Session.TimeoutAt.After(data.Session.TimeoutAt))
previousTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(previousTimeoutDuration), maxDelta)
refreshedTimeoutDuration := time.Duration(refreshedData.Session.TimeoutInSeconds) * time.Second
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(refreshedTimeoutDuration), maxDelta)
}

View File

@@ -1,131 +0,0 @@
package handler_test
import (
"encoding/json"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/session"
)
func TestSession(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Minute
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerbose
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
assert.True(t, data.Session.Active)
assert.True(t, data.Session.TimeoutAt.IsZero())
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
}
func TestSession_WithInactivity(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
cfg.Session.Inactivity = true
cfg.Session.InactivityTimeout = 10 * time.Minute
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerbose
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
maxDelta := 5 * time.Second
assert.True(t, data.Session.Active)
assert.False(t, data.Session.TimeoutAt.IsZero())
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
actualTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(actualTimeoutDuration), maxDelta)
}
func TestSession_WithRefresh(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Minute
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerboseWithRefresh
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
// 1 second < next token refresh <= seconds until token expires
assert.LessOrEqual(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds)
assert.Greater(t, data.Tokens.NextAutoRefreshInSeconds, int64(1))
assert.True(t, data.Tokens.RefreshCooldown)
// 1 second < refresh cooldown <= minimum refresh interval
assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1))
assert.True(t, data.Session.Active)
assert.True(t, data.Session.TimeoutAt.IsZero())
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
}

View File

@@ -3,12 +3,15 @@ package client
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/strings"
@@ -52,11 +55,11 @@ func NewLogin(c *Client, r *http.Request) (*Login, error) {
return nil, fmt.Errorf("generating auth code url: %w", err)
}
cookie := params.cookie(callbackURL)
loginCookie := params.cookie(callbackURL)
return &Login{
authCodeURL: url,
cookie: cookie,
cookie: loginCookie,
params: params,
}, nil
}
@@ -79,11 +82,6 @@ func (l *Login) CodeVerifier() string {
return l.params.CodeVerifier
}
func (l *Login) Cookie(canonicalRedirect string) *openid.LoginCookie {
l.cookie.Referer = canonicalRedirect
return l.cookie
}
func (l *Login) Nonce() string {
return l.params.Nonce
}
@@ -92,6 +90,30 @@ func (l *Login) State() string {
return l.params.State
}
func (l *Login) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error {
l.cookie.Referer = canonicalRedirect
loginCookieJson, err := json.Marshal(l.cookie)
if err != nil {
return fmt.Errorf("marshalling login cookie: %w", err)
}
value := string(loginCookieJson)
err = cookie.EncryptAndSet(w, cookie.Login, value, opts, crypter)
if err != nil {
return err
}
// set a duplicate cookie without the SameSite value set for user agents that do not properly handle SameSite
err = cookie.EncryptAndSet(w, cookie.LoginLegacy, value, opts.WithSameSite(http.SameSiteDefaultMode), crypter)
if err != nil {
return err
}
return nil
}
type loginParameters struct {
*Client
CodeVerifier string

View File

@@ -103,6 +103,8 @@ func TestLogin_URL(t *testing.T) {
assert.ElementsMatch(t, query["code_challenge"], []string{result.CodeChallenge()})
assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"})
assert.Equal(t, client.CodeChallenge(result.CodeVerifier()), result.CodeChallenge())
if test.extraParams != nil {
for key, value := range test.extraParams {
assert.Contains(t, query, key)