mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-13 03:47:02 +00:00
refactor: consolidate handlers
This commit is contained in:
@@ -1,19 +1,29 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
urllib "net/url"
|
||||
"time"
|
||||
|
||||
"github.com/sethvargo/go-retry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/handler/autologin"
|
||||
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
|
||||
"github.com/nais/wonderwall/pkg/ingress"
|
||||
"github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
retrypkg "github.com/nais/wonderwall/pkg/retry"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/url"
|
||||
@@ -34,6 +44,10 @@ type Standalone struct {
|
||||
UpstreamProxy *ReverseProxy
|
||||
}
|
||||
|
||||
type LogoutOptions struct {
|
||||
GlobalLogout bool
|
||||
}
|
||||
|
||||
func NewStandalone(
|
||||
cfg *config.Config,
|
||||
cookieOpts cookie.Options,
|
||||
@@ -114,7 +128,7 @@ func (s *Standalone) GetIngresses() *ingress.Ingresses {
|
||||
}
|
||||
|
||||
func (s *Standalone) GetPath(r *http.Request) string {
|
||||
path, ok := middleware.PathFrom(r.Context())
|
||||
path, ok := mw.PathFrom(r.Context())
|
||||
if !ok {
|
||||
path = s.Ingresses.MatchingPath(r)
|
||||
}
|
||||
@@ -135,37 +149,245 @@ func (s *Standalone) GetSessionConfig() config.Session {
|
||||
}
|
||||
|
||||
func (s *Standalone) Login(w http.ResponseWriter, r *http.Request) {
|
||||
Login(s, w, r)
|
||||
canonicalRedirect := s.GetRedirect().Canonical(r)
|
||||
login, err := s.GetClient().Login(r)
|
||||
if err != nil {
|
||||
if errors.Is(err, openidclient.ErrInvalidSecurityLevel) || errors.Is(err, openidclient.ErrInvalidLocale) {
|
||||
s.GetErrorHandler().BadRequest(w, r, err)
|
||||
} else {
|
||||
s.GetErrorHandler().InternalError(w, r, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
opts := s.GetCookieOptsPathAware(r).
|
||||
WithExpiresIn(1 * time.Hour).
|
||||
WithSameSite(http.SameSiteNoneMode)
|
||||
err = login.SetCookie(w, opts, s.GetCrypter(), canonicalRedirect)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
fields := log.Fields{
|
||||
"redirect_after_login": canonicalRedirect,
|
||||
}
|
||||
mw.LogEntryFrom(r).WithFields(fields).Info("login: redirecting to identity provider")
|
||||
http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) {
|
||||
LoginCallback(s, w, r)
|
||||
opts := s.GetCookieOptsPathAware(r)
|
||||
|
||||
// unconditionally clear login cookies
|
||||
cookie.Clear(w, cookie.Login, opts.WithSameSite(http.SameSiteNoneMode))
|
||||
cookie.Clear(w, cookie.LoginLegacy, opts.WithSameSite(http.SameSiteDefaultMode))
|
||||
|
||||
loginCookie, err := openid.GetLoginCookie(r, s.GetCrypter())
|
||||
if err != nil {
|
||||
msg := "callback: fetching login cookie"
|
||||
if errors.Is(err, http.ErrNoCookie) {
|
||||
msg += ": fallback cookie not found (user might have blocked all cookies, or the callback route was accessed before the login route)"
|
||||
}
|
||||
s.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("%s: %w", msg, err))
|
||||
return
|
||||
}
|
||||
|
||||
loginCallback, err := s.GetClient().LoginCallback(r, loginCookie)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := loginCallback.IdentityProviderError(); err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := loginCallback.StateMismatchError(); err != nil {
|
||||
s.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("callback: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
var tokens *openid.Tokens
|
||||
err = retry.Do(r.Context(), retrypkg.DefaultBackoff, func(ctx context.Context) error {
|
||||
tokens, err = loginCallback.RedeemTokens(ctx)
|
||||
return retry.RetryableError(err)
|
||||
})
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
sessionLifetime := s.GetSessionConfig().MaxLifetime
|
||||
|
||||
ticket, err := s.GetSessions().Create(r, tokens, sessionLifetime)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
err = ticket.Set(w, opts.WithExpiresIn(sessionLifetime), s.GetCrypter())
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
redirect := s.GetRedirect().Clean(r, loginCookie.Referer)
|
||||
|
||||
fields := log.Fields{
|
||||
"redirect_to": redirect,
|
||||
"jti": tokens.IDToken.GetJwtID(),
|
||||
}
|
||||
|
||||
mw.LogEntryFrom(r).WithFields(fields).Info("callback: successful login")
|
||||
metrics.ObserveLogin()
|
||||
cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r))
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (s *Standalone) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
opts := LogoutOptions{
|
||||
GlobalLogout: true,
|
||||
}
|
||||
Logout(s, w, r, opts)
|
||||
s.logout(w, r, opts)
|
||||
}
|
||||
|
||||
func (s *Standalone) LogoutLocal(w http.ResponseWriter, r *http.Request) {
|
||||
opts := LogoutOptions{
|
||||
GlobalLogout: false,
|
||||
}
|
||||
Logout(s, w, r, opts)
|
||||
s.logout(w, r, opts)
|
||||
}
|
||||
|
||||
func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, opts LogoutOptions) {
|
||||
logger := mw.LogEntryFrom(r)
|
||||
logout, err := s.GetClient().Logout(r)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
var idToken string
|
||||
|
||||
sessions := s.GetSessions()
|
||||
|
||||
ticket, err := sessions.GetTicket(r)
|
||||
if err == nil {
|
||||
sessionData, err := sessions.Get(r, ticket)
|
||||
if err == nil && sessionData != nil {
|
||||
idToken = sessionData.IDToken
|
||||
|
||||
err = sessions.Destroy(r, ticket.Key())
|
||||
if err != nil && !errors.Is(err, session.ErrKeyNotFound) {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithField("jti", sessionData.IDTokenJwtID).
|
||||
Info("logout: successful local logout")
|
||||
metrics.ObserveLogout(metrics.LogoutOperationLocal)
|
||||
}
|
||||
}
|
||||
|
||||
cookie.Clear(w, cookie.Session, s.GetCookieOptsPathAware(r))
|
||||
|
||||
if opts.GlobalLogout {
|
||||
logger.Debug("logout: redirecting to identity provider for global/single-logout")
|
||||
metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated)
|
||||
http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Standalone) LogoutCallback(w http.ResponseWriter, r *http.Request) {
|
||||
LogoutCallback(s, w, r)
|
||||
redirect := s.GetClient().LogoutCallback(r).PostLogoutRedirectURI()
|
||||
|
||||
cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r))
|
||||
mw.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect)
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (s *Standalone) LogoutFrontChannel(w http.ResponseWriter, r *http.Request) {
|
||||
LogoutFrontChannel(s, w, r)
|
||||
logger := mw.LogEntryFrom(r)
|
||||
|
||||
// Unconditionally destroy all local references to the session.
|
||||
cookie.Clear(w, cookie.Session, s.GetCookieOptsPathAware(r))
|
||||
|
||||
sessions := s.GetSessions()
|
||||
client := s.GetClient()
|
||||
|
||||
getSessionKey := func(r *http.Request) (string, error) {
|
||||
lfc := client.LogoutFrontchannel(r)
|
||||
|
||||
if lfc.MissingSidParameter() {
|
||||
ticket, err := sessions.GetTicket(r)
|
||||
if err != nil {
|
||||
return ticket.Key(), nil
|
||||
}
|
||||
return "", fmt.Errorf("neither sid parameter nor session ticket found in request: %w", err)
|
||||
}
|
||||
|
||||
sid := lfc.Sid()
|
||||
return sessions.Key(sid), nil
|
||||
}
|
||||
|
||||
key, err := getSessionKey(r)
|
||||
if err != nil {
|
||||
logger.Debugf("front-channel logout: getting session key: %+v; ignoring", err)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
err = sessions.Destroy(r, key)
|
||||
if err != nil {
|
||||
logger.Warnf("front-channel logout: destroying session: %+v", err)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
cookie.Clear(w, cookie.Retry, s.GetCookieOptsPathAware(r))
|
||||
metrics.ObserveLogout(metrics.LogoutOperationFrontChannel)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (s *Standalone) Session(w http.ResponseWriter, r *http.Request) {
|
||||
Session(s, w, r)
|
||||
logger := mw.LogEntryFrom(r)
|
||||
|
||||
ticket, err := s.GetSessions().GetTicket(r)
|
||||
if err != nil {
|
||||
logger.Infof("session/refresh: getting ticket: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := s.GetSessions().Get(r, ticket)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound):
|
||||
logger.Infof("session/info: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
default:
|
||||
logger.Warnf("session/info: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if s.GetSessionConfig().Refresh {
|
||||
err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh())
|
||||
} else {
|
||||
err = json.NewEncoder(w).Encode(data.Metadata.Verbose())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Warnf("session/info: marshalling metadata: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Standalone) SessionRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -174,7 +396,48 @@ func (s *Standalone) SessionRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
SessionRefresh(s, w, r)
|
||||
logger := mw.LogEntryFrom(r)
|
||||
|
||||
ticket, err := s.GetSessions().GetTicket(r)
|
||||
if err != nil {
|
||||
logger.Infof("session/refresh: getting ticket: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := s.GetSessions().Get(r, ticket)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound):
|
||||
logger.Infof("session/refresh: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
default:
|
||||
logger.Warnf("session/refresh: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
data, err = s.GetSessions().Refresh(r, ticket, data)
|
||||
if err != nil {
|
||||
if errors.Is(err, session.ErrInvalidIdpState) || errors.Is(err, session.ErrInvalidSession) {
|
||||
logger.Infof("session/refresh: refreshing: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Warnf("session/refresh: refreshing: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh())
|
||||
if err != nil {
|
||||
logger.Warnf("session/refresh: marshalling metadata: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Standalone) ReverseProxy(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -5,8 +5,11 @@ import (
|
||||
"net/http"
|
||||
urllib "net/url"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/ingress"
|
||||
logentry "github.com/nais/wonderwall/pkg/middleware"
|
||||
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/router/paths"
|
||||
@@ -56,7 +59,31 @@ func NewSSOProxy(cfg *config.Config) (*SSOProxy, error) {
|
||||
}
|
||||
|
||||
func (s *SSOProxy) Login(w http.ResponseWriter, r *http.Request) {
|
||||
LoginSSOProxy(s, w, r)
|
||||
logger := logentry.LogEntryFrom(r)
|
||||
|
||||
target := s.GetSSOServerURL()
|
||||
targetQuery := target.Query()
|
||||
|
||||
// override default query parameters
|
||||
reqQuery := r.URL.Query()
|
||||
if reqQuery.Has(openidclient.SecurityLevelURLParameter) {
|
||||
targetQuery.Set(openidclient.SecurityLevelURLParameter, reqQuery.Get(openidclient.SecurityLevelURLParameter))
|
||||
}
|
||||
if reqQuery.Has(openidclient.LocaleURLParameter) {
|
||||
targetQuery.Set(openidclient.LocaleURLParameter, reqQuery.Get(openidclient.LocaleURLParameter))
|
||||
}
|
||||
|
||||
target.RawQuery = reqQuery.Encode()
|
||||
|
||||
canonicalRedirect := s.GetRedirect().Canonical(r)
|
||||
ssoServerLoginURL := url.Login(target, canonicalRedirect)
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
"redirect_to": ssoServerLoginURL,
|
||||
"redirect_after_login": canonicalRedirect,
|
||||
}).Info("login: redirecting to sso server")
|
||||
|
||||
http.Redirect(w, r, ssoServerLoginURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (s *SSOProxy) LoginCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package handler_test
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -20,6 +21,407 @@ import (
|
||||
urlpkg "github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
resp := localLogin(t, rpClient, idp)
|
||||
loginURL := resp.Location
|
||||
|
||||
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login")
|
||||
|
||||
expectedCallbackURL, err := urlpkg.LoginCallback(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))
|
||||
assert.Equal(t, "/authorize", loginURL.Path)
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), loginURL.Query().Get("acr_values"))
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), loginURL.Query().Get("ui_locales"))
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), loginURL.Query().Get("client_id"))
|
||||
assert.Equal(t, expectedCallbackURL, loginURL.Query().Get("redirect_uri"))
|
||||
assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method"))
|
||||
assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(loginURL.Query().Get("scope"), " "))
|
||||
assert.NotEmpty(t, loginURL.Query().Get("state"))
|
||||
assert.NotEmpty(t, loginURL.Query().Get("nonce"))
|
||||
assert.NotEmpty(t, loginURL.Query().Get("code_challenge"))
|
||||
|
||||
resp = get(t, rpClient, loginURL.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
callbackURL := resp.Location
|
||||
assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state"))
|
||||
assert.NotEmpty(t, callbackURL.Query().Get("code"))
|
||||
}
|
||||
|
||||
func TestCallback(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
}
|
||||
|
||||
func TestCallback_SessionStateRequired(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession")
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
resp := authorize(t, rpClient, idp)
|
||||
|
||||
// Get callback URL after successful auth
|
||||
params := resp.Location.Query()
|
||||
sessionState := params.Get("session_state")
|
||||
assert.NotEmpty(t, sessionState)
|
||||
|
||||
callback(t, rpClient, resp)
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := selfInitiatedLogout(t, rpClient, idp)
|
||||
|
||||
// Get endsession endpoint after local logout
|
||||
endsessionURL := resp.Location
|
||||
|
||||
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback")
|
||||
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
endsessionParams := endsessionURL.Query()
|
||||
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
|
||||
assert.Equal(t, "/endsession", endsessionURL.Path)
|
||||
assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"])
|
||||
assert.NotEmpty(t, endsessionParams["id_token_hint"])
|
||||
}
|
||||
|
||||
func TestLogoutLocal(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
localLogout(t, rpClient, idp)
|
||||
}
|
||||
|
||||
func TestLogoutCallback(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
logout(t, rpClient, idp)
|
||||
}
|
||||
|
||||
func TestFrontChannelLogout(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport()
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
sessionCookie := login(t, rpClient, idp)
|
||||
|
||||
// Trigger front-channel logout
|
||||
sid := func(r *http.Request) string {
|
||||
r.AddCookie(sessionCookie)
|
||||
|
||||
ticket, err := session.GetTicket(r, idp.RelyingPartyHandler.GetCrypter())
|
||||
assert.NoError(t, err)
|
||||
|
||||
data, err := idp.RelyingPartyHandler.GetSessions().Get(r, ticket)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return data.ExternalSessionID
|
||||
}
|
||||
|
||||
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := idp.GetRequest(frontchannelLogoutURL.String())
|
||||
|
||||
values := url.Values{}
|
||||
values.Add("sid", sid(req))
|
||||
values.Add("iss", idp.OpenIDConfig.Provider().Issuer())
|
||||
frontchannelLogoutURL.RawQuery = values.Encode()
|
||||
|
||||
resp := get(t, rpClient, frontchannelLogoutURL.String())
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestSessionRefresh(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
// get initial session info
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerboseWithRefresh
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// wait until refresh cooldown has reached zero before refresh
|
||||
waitForRefreshCooldownTimer(t, idp, rpClient)
|
||||
|
||||
resp = sessionRefresh(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var refreshedData session.MetadataVerboseWithRefresh
|
||||
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// session create and end times should be unchanged
|
||||
assert.WithinDuration(t, data.Session.CreatedAt, refreshedData.Session.CreatedAt, 0)
|
||||
assert.WithinDuration(t, data.Session.EndsAt, refreshedData.Session.EndsAt, 0)
|
||||
|
||||
// token expiration and refresh times should be later than before
|
||||
assert.True(t, refreshedData.Tokens.ExpireAt.After(data.Tokens.ExpireAt))
|
||||
assert.True(t, refreshedData.Tokens.RefreshedAt.After(data.Tokens.RefreshedAt))
|
||||
|
||||
allowedSkew := 5 * time.Second
|
||||
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.Tokens.ExpireAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now(), refreshedData.Tokens.RefreshedAt, allowedSkew)
|
||||
|
||||
sessionEndDuration := time.Duration(refreshedData.Session.EndsInSeconds) * time.Second
|
||||
// 1 second < time until session ends <= configured max session lifetime
|
||||
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
||||
assert.Greater(t, sessionEndDuration, time.Second)
|
||||
|
||||
tokenExpiryDuration := time.Duration(refreshedData.Tokens.ExpireInSeconds) * time.Second
|
||||
// 1 second < time until token expires <= max duration for tokens from IDP
|
||||
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
||||
assert.Greater(t, tokenExpiryDuration, time.Second)
|
||||
|
||||
// 1 second < next token refresh <= seconds until token expires
|
||||
assert.LessOrEqual(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds)
|
||||
assert.Greater(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(1))
|
||||
|
||||
assert.True(t, refreshedData.Tokens.RefreshCooldown)
|
||||
// 1 second < refresh cooldown <= minimum refresh interval
|
||||
assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
|
||||
assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1))
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.True(t, refreshedData.Session.Active)
|
||||
|
||||
assert.True(t, data.Session.TimeoutAt.IsZero())
|
||||
assert.True(t, refreshedData.Session.TimeoutAt.IsZero())
|
||||
|
||||
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
||||
assert.Equal(t, int64(-1), refreshedData.Session.TimeoutInSeconds)
|
||||
}
|
||||
|
||||
func TestSessionRefresh_Disabled(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = false
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionRefresh(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestSessionRefresh_WithInactivity(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
cfg.Session.Inactivity = true
|
||||
cfg.Session.InactivityTimeout = 10 * time.Minute
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
// get initial session info
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerboseWithRefresh
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// wait until refresh cooldown has reached zero before refresh
|
||||
waitForRefreshCooldownTimer(t, idp, rpClient)
|
||||
|
||||
resp = sessionRefresh(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var refreshedData session.MetadataVerboseWithRefresh
|
||||
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
maxDelta := 5 * time.Second
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.True(t, refreshedData.Session.Active)
|
||||
|
||||
assert.False(t, data.Session.TimeoutAt.IsZero())
|
||||
assert.False(t, refreshedData.Session.TimeoutAt.IsZero())
|
||||
|
||||
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
|
||||
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
|
||||
assert.WithinDuration(t, expectedTimeoutAt, refreshedData.Session.TimeoutAt, maxDelta)
|
||||
|
||||
assert.True(t, refreshedData.Session.TimeoutAt.After(data.Session.TimeoutAt))
|
||||
|
||||
previousTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
|
||||
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(previousTimeoutDuration), maxDelta)
|
||||
|
||||
refreshedTimeoutDuration := time.Duration(refreshedData.Session.TimeoutInSeconds) * time.Second
|
||||
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(refreshedTimeoutDuration), maxDelta)
|
||||
}
|
||||
|
||||
func TestSession(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Minute
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerbose
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allowedSkew := 5 * time.Second
|
||||
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
|
||||
|
||||
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
|
||||
// 1 second < time until session ends <= configured max session lifetime
|
||||
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
||||
assert.Greater(t, sessionEndDuration, time.Second)
|
||||
|
||||
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
|
||||
// 1 second < time until token expires <= max duration for tokens from IDP
|
||||
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
||||
assert.Greater(t, tokenExpiryDuration, time.Second)
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.True(t, data.Session.TimeoutAt.IsZero())
|
||||
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
||||
}
|
||||
|
||||
func TestSession_WithInactivity(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
cfg.Session.Inactivity = true
|
||||
cfg.Session.InactivityTimeout = 10 * time.Minute
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerbose
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
maxDelta := 5 * time.Second
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.False(t, data.Session.TimeoutAt.IsZero())
|
||||
|
||||
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
|
||||
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
|
||||
|
||||
actualTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
|
||||
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(actualTimeoutDuration), maxDelta)
|
||||
}
|
||||
|
||||
func TestSession_WithRefresh(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Minute
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerboseWithRefresh
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allowedSkew := 5 * time.Second
|
||||
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
|
||||
|
||||
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
|
||||
// 1 second < time until session ends <= configured max session lifetime
|
||||
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
||||
assert.Greater(t, sessionEndDuration, time.Second)
|
||||
|
||||
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
|
||||
// 1 second < time until token expires <= max duration for tokens from IDP
|
||||
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
||||
assert.Greater(t, tokenExpiryDuration, time.Second)
|
||||
|
||||
// 1 second < next token refresh <= seconds until token expires
|
||||
assert.LessOrEqual(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds)
|
||||
assert.Greater(t, data.Tokens.NextAutoRefreshInSeconds, int64(1))
|
||||
|
||||
assert.True(t, data.Tokens.RefreshCooldown)
|
||||
// 1 second < refresh cooldown <= minimum refresh interval
|
||||
assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
|
||||
assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1))
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.True(t, data.Session.TimeoutAt.IsZero())
|
||||
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
||||
}
|
||||
|
||||
func localLogin(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
|
||||
// First, run /oauth2/login to set cookies
|
||||
loginURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/login")
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
|
||||
logentry "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
||||
urlpkg "github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
const (
|
||||
CookieLifetime = 1 * time.Hour
|
||||
)
|
||||
|
||||
type LoginSource interface {
|
||||
GetClient() *openidclient.Client
|
||||
GetCookieOptsPathAware(r *http.Request) cookie.Options
|
||||
GetCrypter() crypto.Crypter
|
||||
GetErrorHandler() errorhandler.Handler
|
||||
GetRedirect() urlpkg.Redirect
|
||||
}
|
||||
|
||||
func Login(src LoginSource, w http.ResponseWriter, r *http.Request) {
|
||||
canonicalRedirect := src.GetRedirect().Canonical(r)
|
||||
login, err := src.GetClient().Login(r)
|
||||
if err != nil {
|
||||
if errors.Is(err, openidclient.ErrInvalidSecurityLevel) || errors.Is(err, openidclient.ErrInvalidLocale) {
|
||||
src.GetErrorHandler().BadRequest(w, r, err)
|
||||
} else {
|
||||
src.GetErrorHandler().InternalError(w, r, err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = setLoginCookies(src, w, r, login.Cookie(canonicalRedirect))
|
||||
if err != nil {
|
||||
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
fields := log.Fields{
|
||||
"redirect_after_login": canonicalRedirect,
|
||||
}
|
||||
logentry.LogEntryFrom(r).WithFields(fields).Info("login: redirecting to identity provider")
|
||||
http.Redirect(w, r, login.AuthCodeURL(), http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
type LoginSSOProxySource interface {
|
||||
GetSSOServerURL() *url.URL
|
||||
GetRedirect() urlpkg.Redirect
|
||||
}
|
||||
|
||||
func LoginSSOProxy(src LoginSSOProxySource, w http.ResponseWriter, r *http.Request) {
|
||||
logger := logentry.LogEntryFrom(r)
|
||||
|
||||
target := src.GetSSOServerURL()
|
||||
targetQuery := target.Query()
|
||||
|
||||
// override default query parameters
|
||||
reqQuery := r.URL.Query()
|
||||
if reqQuery.Has(openidclient.SecurityLevelURLParameter) {
|
||||
targetQuery.Set(openidclient.SecurityLevelURLParameter, reqQuery.Get(openidclient.SecurityLevelURLParameter))
|
||||
}
|
||||
if reqQuery.Has(openidclient.LocaleURLParameter) {
|
||||
targetQuery.Set(openidclient.LocaleURLParameter, reqQuery.Get(openidclient.LocaleURLParameter))
|
||||
}
|
||||
|
||||
target.RawQuery = reqQuery.Encode()
|
||||
|
||||
canonicalRedirect := src.GetRedirect().Canonical(r)
|
||||
ssoServerLoginURL := urlpkg.Login(target, canonicalRedirect)
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
"redirect_to": ssoServerLoginURL,
|
||||
"redirect_after_login": canonicalRedirect,
|
||||
}).Info("login: redirecting to sso server")
|
||||
|
||||
http.Redirect(w, r, ssoServerLoginURL, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func setLoginCookies(src LoginSource, w http.ResponseWriter, r *http.Request, loginCookie *openid.LoginCookie) error {
|
||||
loginCookieJson, err := json.Marshal(loginCookie)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling login cookie: %w", err)
|
||||
}
|
||||
|
||||
opts := src.GetCookieOptsPathAware(r).
|
||||
WithExpiresIn(CookieLifetime).
|
||||
WithSameSite(http.SameSiteNoneMode)
|
||||
value := string(loginCookieJson)
|
||||
|
||||
err = cookie.EncryptAndSet(w, cookie.Login, value, opts, src.GetCrypter())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set a duplicate cookie without the SameSite value set for user agents that do not properly handle SameSite
|
||||
err = cookie.EncryptAndSet(w, cookie.LoginLegacy, value, opts.WithSameSite(http.SameSiteDefaultMode), src.GetCrypter())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,125 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/sethvargo/go-retry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
logentry "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
||||
retrypkg "github.com/nais/wonderwall/pkg/retry"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
type LoginCallbackSource interface {
|
||||
GetClient() *openidclient.Client
|
||||
GetCookieOptions() cookie.Options
|
||||
GetCookieOptsPathAware(r *http.Request) cookie.Options
|
||||
GetCrypter() crypto.Crypter
|
||||
GetErrorHandler() errorhandler.Handler
|
||||
GetRedirect() url.Redirect
|
||||
GetSessions() *session.Handler
|
||||
GetSessionConfig() config.Session
|
||||
}
|
||||
|
||||
func LoginCallback(src LoginCallbackSource, w http.ResponseWriter, r *http.Request) {
|
||||
// unconditionally clear login cookie
|
||||
clearLoginCookies(src, w, r)
|
||||
|
||||
loginCookie, err := openid.GetLoginCookie(r, src.GetCrypter())
|
||||
if err != nil {
|
||||
msg := "callback: fetching login cookie"
|
||||
if errors.Is(err, http.ErrNoCookie) {
|
||||
msg += ": fallback cookie not found (user might have blocked all cookies, or the callback route was accessed before the login route)"
|
||||
}
|
||||
src.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("%s: %w", msg, err))
|
||||
return
|
||||
}
|
||||
|
||||
loginCallback, err := src.GetClient().LoginCallback(r, loginCookie)
|
||||
if err != nil {
|
||||
src.GetErrorHandler().InternalError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := loginCallback.IdentityProviderError(); err != nil {
|
||||
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := loginCallback.StateMismatchError(); err != nil {
|
||||
src.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("callback: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
tokens, err := redeemValidTokens(r, loginCallback)
|
||||
if err != nil {
|
||||
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
sessionLifetime := src.GetSessionConfig().MaxLifetime
|
||||
|
||||
ticket, err := src.GetSessions().Create(r, tokens, sessionLifetime)
|
||||
if err != nil {
|
||||
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
opts := src.GetCookieOptsPathAware(r).
|
||||
WithExpiresIn(sessionLifetime)
|
||||
err = ticket.Set(w, opts, src.GetCrypter())
|
||||
if err != nil {
|
||||
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
redirect := src.GetRedirect().Clean(r, loginCookie.Referer)
|
||||
|
||||
logSuccessfulLogin(r, tokens, redirect)
|
||||
cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r))
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func clearLoginCookies(src LogoutCallbackSource, w http.ResponseWriter, r *http.Request) {
|
||||
opts := src.GetCookieOptsPathAware(r)
|
||||
cookie.Clear(w, cookie.Login, opts.WithSameSite(http.SameSiteNoneMode))
|
||||
cookie.Clear(w, cookie.LoginLegacy, opts.WithSameSite(http.SameSiteDefaultMode))
|
||||
}
|
||||
|
||||
func redeemValidTokens(r *http.Request, loginCallback *openidclient.LoginCallback) (*openid.Tokens, error) {
|
||||
var tokens *openid.Tokens
|
||||
var err error
|
||||
|
||||
retryable := func(ctx context.Context) error {
|
||||
tokens, err = loginCallback.RedeemTokens(ctx)
|
||||
return retry.RetryableError(err)
|
||||
}
|
||||
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string) {
|
||||
fields := log.Fields{
|
||||
"redirect_to": referer,
|
||||
"jti": tokens.IDToken.GetJwtID(),
|
||||
}
|
||||
|
||||
logentry.LogEntryFrom(r).WithFields(fields).Info("callback: successful login")
|
||||
metrics.ObserveLogin()
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
)
|
||||
|
||||
func TestCallback(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
}
|
||||
|
||||
func TestCallback_SessionStateRequired(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession")
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
resp := authorize(t, rpClient, idp)
|
||||
|
||||
// Get callback URL after successful auth
|
||||
params := resp.Location.Query()
|
||||
sessionState := params.Get("session_state")
|
||||
assert.NotEmpty(t, sessionState)
|
||||
|
||||
callback(t, rpClient, resp)
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
urlpkg "github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
resp := localLogin(t, rpClient, idp)
|
||||
loginURL := resp.Location
|
||||
|
||||
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login")
|
||||
|
||||
expectedCallbackURL, err := urlpkg.LoginCallback(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))
|
||||
assert.Equal(t, "/authorize", loginURL.Path)
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), loginURL.Query().Get("acr_values"))
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), loginURL.Query().Get("ui_locales"))
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), loginURL.Query().Get("client_id"))
|
||||
assert.Equal(t, expectedCallbackURL, loginURL.Query().Get("redirect_uri"))
|
||||
assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method"))
|
||||
assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(loginURL.Query().Get("scope"), " "))
|
||||
assert.NotEmpty(t, loginURL.Query().Get("state"))
|
||||
assert.NotEmpty(t, loginURL.Query().Get("nonce"))
|
||||
assert.NotEmpty(t, loginURL.Query().Get("code_challenge"))
|
||||
|
||||
resp = get(t, rpClient, loginURL.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
callbackURL := resp.Location
|
||||
assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state"))
|
||||
assert.NotEmpty(t, callbackURL.Query().Get("code"))
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
logentry "github.com/nais/wonderwall/pkg/middleware"
|
||||
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
type LogoutSource interface {
|
||||
GetClient() *openidclient.Client
|
||||
GetCookieOptions() cookie.Options
|
||||
GetCookieOptsPathAware(r *http.Request) cookie.Options
|
||||
GetErrorHandler() errorhandler.Handler
|
||||
GetSessions() *session.Handler
|
||||
}
|
||||
|
||||
type LogoutOptions struct {
|
||||
GlobalLogout bool
|
||||
}
|
||||
|
||||
func Logout(src LogoutSource, w http.ResponseWriter, r *http.Request, opts LogoutOptions) {
|
||||
logger := logentry.LogEntryFrom(r)
|
||||
logout, err := src.GetClient().Logout(r)
|
||||
if err != nil {
|
||||
src.GetErrorHandler().InternalError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
var idToken string
|
||||
|
||||
sessions := src.GetSessions()
|
||||
|
||||
ticket, err := sessions.GetTicket(r)
|
||||
if err == nil {
|
||||
sessionData, err := sessions.Get(r, ticket)
|
||||
if err == nil && sessionData != nil {
|
||||
idToken = sessionData.IDToken
|
||||
|
||||
err = sessions.Destroy(r, ticket.Key())
|
||||
if err != nil && !errors.Is(err, session.ErrKeyNotFound) {
|
||||
src.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithField("jti", sessionData.IDTokenJwtID).
|
||||
Info("logout: successful local logout")
|
||||
metrics.ObserveLogout(metrics.LogoutOperationLocal)
|
||||
}
|
||||
}
|
||||
|
||||
cookie.Clear(w, cookie.Session, src.GetCookieOptsPathAware(r))
|
||||
|
||||
if opts.GlobalLogout {
|
||||
logger.Debug("logout: redirecting to identity provider for global/single-logout")
|
||||
metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated)
|
||||
http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
logentry "github.com/nais/wonderwall/pkg/middleware"
|
||||
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
||||
)
|
||||
|
||||
type LogoutCallbackSource interface {
|
||||
GetClient() *openidclient.Client
|
||||
GetCookieOptsPathAware(r *http.Request) cookie.Options
|
||||
}
|
||||
|
||||
func LogoutCallback(src LogoutCallbackSource, w http.ResponseWriter, r *http.Request) {
|
||||
redirect := src.GetClient().LogoutCallback(r).PostLogoutRedirectURI()
|
||||
|
||||
cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r))
|
||||
logentry.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect)
|
||||
http.Redirect(w, r, redirect, http.StatusTemporaryRedirect)
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
)
|
||||
|
||||
func TestLogoutCallback(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
logout(t, rpClient, idp)
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
openidclient "github.com/nais/wonderwall/pkg/openid/client"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
type LogoutFrontChannelSource interface {
|
||||
GetClient() *openidclient.Client
|
||||
GetCookieOptions() cookie.Options
|
||||
GetCookieOptsPathAware(r *http.Request) cookie.Options
|
||||
GetSessions() *session.Handler
|
||||
}
|
||||
|
||||
func LogoutFrontChannel(src LogoutFrontChannelSource, w http.ResponseWriter, r *http.Request) {
|
||||
logger := mw.LogEntryFrom(r)
|
||||
|
||||
// Unconditionally destroy all local references to the session.
|
||||
cookie.Clear(w, cookie.Session, src.GetCookieOptsPathAware(r))
|
||||
|
||||
sessions := src.GetSessions()
|
||||
client := src.GetClient()
|
||||
key, err := getSessionKey(r, sessions, client)
|
||||
if err != nil {
|
||||
logger.Debugf("front-channel logout: getting session key: %+v; ignoring", err)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
err = sessions.Destroy(r, key)
|
||||
if err != nil {
|
||||
logger.Warnf("front-channel logout: destroying session: %+v", err)
|
||||
w.WriteHeader(http.StatusAccepted)
|
||||
return
|
||||
}
|
||||
|
||||
cookie.Clear(w, cookie.Retry, src.GetCookieOptsPathAware(r))
|
||||
metrics.ObserveLogout(metrics.LogoutOperationFrontChannel)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func getSessionKey(r *http.Request, sessions *session.Handler, client *openidclient.Client) (string, error) {
|
||||
logoutFrontchannel := client.LogoutFrontchannel(r)
|
||||
|
||||
if logoutFrontchannel.MissingSidParameter() {
|
||||
ticket, err := sessions.GetTicket(r)
|
||||
if err != nil {
|
||||
return ticket.Key(), nil
|
||||
}
|
||||
return "", fmt.Errorf("neither sid parameter nor session ticket found in request: %w", err)
|
||||
}
|
||||
|
||||
sid := logoutFrontchannel.Sid()
|
||||
return sessions.Key(sid), nil
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
func TestFrontChannelLogout(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport()
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
sessionCookie := login(t, rpClient, idp)
|
||||
|
||||
// Trigger front-channel logout
|
||||
sid := func(r *http.Request) string {
|
||||
r.AddCookie(sessionCookie)
|
||||
|
||||
ticket, err := session.GetTicket(r, idp.RelyingPartyHandler.GetCrypter())
|
||||
assert.NoError(t, err)
|
||||
|
||||
data, err := idp.RelyingPartyHandler.GetSessions().Get(r, ticket)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return data.ExternalSessionID
|
||||
}
|
||||
|
||||
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := idp.GetRequest(frontchannelLogoutURL.String())
|
||||
|
||||
values := url.Values{}
|
||||
values.Add("sid", sid(req))
|
||||
values.Add("iss", idp.OpenIDConfig.Provider().Issuer())
|
||||
frontchannelLogoutURL.RawQuery = values.Encode()
|
||||
|
||||
resp := get(t, rpClient, frontchannelLogoutURL.String())
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
urlpkg "github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := selfInitiatedLogout(t, rpClient, idp)
|
||||
|
||||
// Get endsession endpoint after local logout
|
||||
endsessionURL := resp.Location
|
||||
|
||||
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback")
|
||||
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
endsessionParams := endsessionURL.Query()
|
||||
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
|
||||
assert.Equal(t, "/endsession", endsessionURL.Path)
|
||||
assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"])
|
||||
assert.NotEmpty(t, endsessionParams["id_token_hint"])
|
||||
}
|
||||
|
||||
func TestLogoutLocal(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
localLogout(t, rpClient, idp)
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
type SessionSource interface {
|
||||
GetSessions() *session.Handler
|
||||
GetSessionConfig() config.Session
|
||||
}
|
||||
|
||||
func Session(src SessionSource, w http.ResponseWriter, r *http.Request) {
|
||||
logger := mw.LogEntryFrom(r)
|
||||
|
||||
ticket, err := src.GetSessions().GetTicket(r)
|
||||
if err != nil {
|
||||
logger.Infof("session/refresh: getting ticket: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := src.GetSessions().Get(r, ticket)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound):
|
||||
logger.Infof("session/info: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
default:
|
||||
logger.Warnf("session/info: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if src.GetSessionConfig().Refresh {
|
||||
err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh())
|
||||
} else {
|
||||
err = json.NewEncoder(w).Encode(data.Metadata.Verbose())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Warnf("session/info: marshalling metadata: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
type SessionRefreshSource interface {
|
||||
GetSessions() *session.Handler
|
||||
}
|
||||
|
||||
func SessionRefresh(src SessionRefreshSource, w http.ResponseWriter, r *http.Request) {
|
||||
logger := mw.LogEntryFrom(r)
|
||||
|
||||
ticket, err := src.GetSessions().GetTicket(r)
|
||||
if err != nil {
|
||||
logger.Infof("session/refresh: getting ticket: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
data, err := src.GetSessions().Get(r, ticket)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, session.ErrInvalidSession), errors.Is(err, session.ErrKeyNotFound):
|
||||
logger.Infof("session/refresh: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
default:
|
||||
logger.Warnf("session/refresh: getting session: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
data, err = src.GetSessions().Refresh(r, ticket, data)
|
||||
if err != nil {
|
||||
if errors.Is(err, session.ErrInvalidIdpState) || errors.Is(err, session.ErrInvalidSession) {
|
||||
logger.Infof("session/refresh: refreshing: %+v", err)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
logger.Warnf("session/refresh: refreshing: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w).Encode(data.Metadata.VerboseWithRefresh())
|
||||
if err != nil {
|
||||
logger.Warnf("session/refresh: marshalling metadata: %+v", err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
func TestSessionRefresh(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
// get initial session info
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerboseWithRefresh
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// wait until refresh cooldown has reached zero before refresh
|
||||
waitForRefreshCooldownTimer(t, idp, rpClient)
|
||||
|
||||
resp = sessionRefresh(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var refreshedData session.MetadataVerboseWithRefresh
|
||||
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// session create and end times should be unchanged
|
||||
assert.WithinDuration(t, data.Session.CreatedAt, refreshedData.Session.CreatedAt, 0)
|
||||
assert.WithinDuration(t, data.Session.EndsAt, refreshedData.Session.EndsAt, 0)
|
||||
|
||||
// token expiration and refresh times should be later than before
|
||||
assert.True(t, refreshedData.Tokens.ExpireAt.After(data.Tokens.ExpireAt))
|
||||
assert.True(t, refreshedData.Tokens.RefreshedAt.After(data.Tokens.RefreshedAt))
|
||||
|
||||
allowedSkew := 5 * time.Second
|
||||
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.Tokens.ExpireAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now(), refreshedData.Tokens.RefreshedAt, allowedSkew)
|
||||
|
||||
sessionEndDuration := time.Duration(refreshedData.Session.EndsInSeconds) * time.Second
|
||||
// 1 second < time until session ends <= configured max session lifetime
|
||||
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
||||
assert.Greater(t, sessionEndDuration, time.Second)
|
||||
|
||||
tokenExpiryDuration := time.Duration(refreshedData.Tokens.ExpireInSeconds) * time.Second
|
||||
// 1 second < time until token expires <= max duration for tokens from IDP
|
||||
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
||||
assert.Greater(t, tokenExpiryDuration, time.Second)
|
||||
|
||||
// 1 second < next token refresh <= seconds until token expires
|
||||
assert.LessOrEqual(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds)
|
||||
assert.Greater(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(1))
|
||||
|
||||
assert.True(t, refreshedData.Tokens.RefreshCooldown)
|
||||
// 1 second < refresh cooldown <= minimum refresh interval
|
||||
assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
|
||||
assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1))
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.True(t, refreshedData.Session.Active)
|
||||
|
||||
assert.True(t, data.Session.TimeoutAt.IsZero())
|
||||
assert.True(t, refreshedData.Session.TimeoutAt.IsZero())
|
||||
|
||||
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
||||
assert.Equal(t, int64(-1), refreshedData.Session.TimeoutInSeconds)
|
||||
}
|
||||
|
||||
func TestSessionRefresh_Disabled(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = false
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionRefresh(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
func TestSessionRefresh_WithInactivity(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
cfg.Session.Inactivity = true
|
||||
cfg.Session.InactivityTimeout = 10 * time.Minute
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
// get initial session info
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerboseWithRefresh
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// wait until refresh cooldown has reached zero before refresh
|
||||
waitForRefreshCooldownTimer(t, idp, rpClient)
|
||||
|
||||
resp = sessionRefresh(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var refreshedData session.MetadataVerboseWithRefresh
|
||||
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
|
||||
assert.NoError(t, err)
|
||||
|
||||
maxDelta := 5 * time.Second
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.True(t, refreshedData.Session.Active)
|
||||
|
||||
assert.False(t, data.Session.TimeoutAt.IsZero())
|
||||
assert.False(t, refreshedData.Session.TimeoutAt.IsZero())
|
||||
|
||||
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
|
||||
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
|
||||
assert.WithinDuration(t, expectedTimeoutAt, refreshedData.Session.TimeoutAt, maxDelta)
|
||||
|
||||
assert.True(t, refreshedData.Session.TimeoutAt.After(data.Session.TimeoutAt))
|
||||
|
||||
previousTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
|
||||
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(previousTimeoutDuration), maxDelta)
|
||||
|
||||
refreshedTimeoutDuration := time.Duration(refreshedData.Session.TimeoutInSeconds) * time.Second
|
||||
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(refreshedTimeoutDuration), maxDelta)
|
||||
}
|
||||
@@ -1,131 +0,0 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
func TestSession(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Minute
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerbose
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allowedSkew := 5 * time.Second
|
||||
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
|
||||
|
||||
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
|
||||
// 1 second < time until session ends <= configured max session lifetime
|
||||
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
||||
assert.Greater(t, sessionEndDuration, time.Second)
|
||||
|
||||
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
|
||||
// 1 second < time until token expires <= max duration for tokens from IDP
|
||||
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
||||
assert.Greater(t, tokenExpiryDuration, time.Second)
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.True(t, data.Session.TimeoutAt.IsZero())
|
||||
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
||||
}
|
||||
|
||||
func TestSession_WithInactivity(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
cfg.Session.Inactivity = true
|
||||
cfg.Session.InactivityTimeout = 10 * time.Minute
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerbose
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
maxDelta := 5 * time.Second
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.False(t, data.Session.TimeoutAt.IsZero())
|
||||
|
||||
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
|
||||
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
|
||||
|
||||
actualTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
|
||||
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(actualTimeoutDuration), maxDelta)
|
||||
}
|
||||
|
||||
func TestSession_WithRefresh(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.Session.Refresh = true
|
||||
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
idp.ProviderHandler.TokenDuration = 5 * time.Minute
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := sessionInfo(t, idp, rpClient)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
var data session.MetadataVerboseWithRefresh
|
||||
err := json.Unmarshal([]byte(resp.Body), &data)
|
||||
assert.NoError(t, err)
|
||||
|
||||
allowedSkew := 5 * time.Second
|
||||
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
|
||||
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
|
||||
|
||||
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
|
||||
// 1 second < time until session ends <= configured max session lifetime
|
||||
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
||||
assert.Greater(t, sessionEndDuration, time.Second)
|
||||
|
||||
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
|
||||
// 1 second < time until token expires <= max duration for tokens from IDP
|
||||
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
||||
assert.Greater(t, tokenExpiryDuration, time.Second)
|
||||
|
||||
// 1 second < next token refresh <= seconds until token expires
|
||||
assert.LessOrEqual(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds)
|
||||
assert.Greater(t, data.Tokens.NextAutoRefreshInSeconds, int64(1))
|
||||
|
||||
assert.True(t, data.Tokens.RefreshCooldown)
|
||||
// 1 second < refresh cooldown <= minimum refresh interval
|
||||
assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
|
||||
assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1))
|
||||
|
||||
assert.True(t, data.Session.Active)
|
||||
assert.True(t, data.Session.TimeoutAt.IsZero())
|
||||
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
||||
}
|
||||
@@ -3,12 +3,15 @@ package client
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/config"
|
||||
"github.com/nais/wonderwall/pkg/strings"
|
||||
@@ -52,11 +55,11 @@ func NewLogin(c *Client, r *http.Request) (*Login, error) {
|
||||
return nil, fmt.Errorf("generating auth code url: %w", err)
|
||||
}
|
||||
|
||||
cookie := params.cookie(callbackURL)
|
||||
loginCookie := params.cookie(callbackURL)
|
||||
|
||||
return &Login{
|
||||
authCodeURL: url,
|
||||
cookie: cookie,
|
||||
cookie: loginCookie,
|
||||
params: params,
|
||||
}, nil
|
||||
}
|
||||
@@ -79,11 +82,6 @@ func (l *Login) CodeVerifier() string {
|
||||
return l.params.CodeVerifier
|
||||
}
|
||||
|
||||
func (l *Login) Cookie(canonicalRedirect string) *openid.LoginCookie {
|
||||
l.cookie.Referer = canonicalRedirect
|
||||
return l.cookie
|
||||
}
|
||||
|
||||
func (l *Login) Nonce() string {
|
||||
return l.params.Nonce
|
||||
}
|
||||
@@ -92,6 +90,30 @@ func (l *Login) State() string {
|
||||
return l.params.State
|
||||
}
|
||||
|
||||
func (l *Login) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error {
|
||||
l.cookie.Referer = canonicalRedirect
|
||||
|
||||
loginCookieJson, err := json.Marshal(l.cookie)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling login cookie: %w", err)
|
||||
}
|
||||
|
||||
value := string(loginCookieJson)
|
||||
|
||||
err = cookie.EncryptAndSet(w, cookie.Login, value, opts, crypter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// set a duplicate cookie without the SameSite value set for user agents that do not properly handle SameSite
|
||||
err = cookie.EncryptAndSet(w, cookie.LoginLegacy, value, opts.WithSameSite(http.SameSiteDefaultMode), crypter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type loginParameters struct {
|
||||
*Client
|
||||
CodeVerifier string
|
||||
|
||||
@@ -103,6 +103,8 @@ func TestLogin_URL(t *testing.T) {
|
||||
assert.ElementsMatch(t, query["code_challenge"], []string{result.CodeChallenge()})
|
||||
assert.ElementsMatch(t, query["code_challenge_method"], []string{"S256"})
|
||||
|
||||
assert.Equal(t, client.CodeChallenge(result.CodeVerifier()), result.CodeChallenge())
|
||||
|
||||
if test.extraParams != nil {
|
||||
for key, value := range test.extraParams {
|
||||
assert.Contains(t, query, key)
|
||||
|
||||
Reference in New Issue
Block a user