mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-06 16:36:51 +00:00
If autologin is enabled, check for headers that indicate that the request is a navigation request
and respond appropriately.
A navigation request is assumed to match all of the following:
- uses the GET HTTP method
- either:
- a) sends the fetch metadata headers, specifically
`Sec-Fetch-Mode=navigate` and `Sec-Fetch-Dest=document`, or (if
unsupported by the browser)
- b) sends the `Accept` header with a value that contains
`text/html` (which most browsers do by default for navigation
requests, the exception being IE8 AFAIK)
Non-navigation requests (e.g. fetch / xhr / ajax requests) will receive a
401 Unauthorized, with the Location header set to the login endpoint.
The redirect parameter is also set to point back to the URL found in the
Referer header (though with the scheme and host removed to only allow
redirects relative to the origin host.)
With this fix, autologin will also intercept requests other than GET.
This is to improve the security posture of upstreams that assume that autologin
enforces authentication for all methods.
Fixes #156.
816 lines
25 KiB
Go
816 lines
25 KiB
Go
package handler_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/nais/wonderwall/pkg/cookie"
|
|
"github.com/nais/wonderwall/pkg/mock"
|
|
"github.com/nais/wonderwall/pkg/session"
|
|
urlpkg "github.com/nais/wonderwall/pkg/url"
|
|
)
|
|
|
|
func TestLogin(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
|
|
resp := localLogin(t, rpClient, idp)
|
|
loginURL := resp.Location
|
|
|
|
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login")
|
|
|
|
expectedCallbackURL, err := urlpkg.LoginCallback(req)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))
|
|
assert.Equal(t, "/authorize", loginURL.Path)
|
|
assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), loginURL.Query().Get("acr_values"))
|
|
assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), loginURL.Query().Get("ui_locales"))
|
|
assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), loginURL.Query().Get("client_id"))
|
|
assert.Equal(t, expectedCallbackURL, loginURL.Query().Get("redirect_uri"))
|
|
assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method"))
|
|
assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(loginURL.Query().Get("scope"), " "))
|
|
assert.NotEmpty(t, loginURL.Query().Get("state"))
|
|
assert.NotEmpty(t, loginURL.Query().Get("nonce"))
|
|
assert.NotEmpty(t, loginURL.Query().Get("code_challenge"))
|
|
|
|
resp = get(t, rpClient, loginURL.String())
|
|
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
|
|
|
callbackURL := resp.Location
|
|
assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state"))
|
|
assert.NotEmpty(t, callbackURL.Query().Get("code"))
|
|
}
|
|
|
|
func TestCallback(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
}
|
|
|
|
func TestCallback_SessionStateRequired(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession")
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
|
|
resp := authorize(t, rpClient, idp)
|
|
|
|
// Get callback URL after successful auth
|
|
params := resp.Location.Query()
|
|
sessionState := params.Get("session_state")
|
|
assert.NotEmpty(t, sessionState)
|
|
|
|
callback(t, rpClient, resp)
|
|
}
|
|
|
|
func TestLogout(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
selfInitiatedLogout(t, rpClient, idp)
|
|
}
|
|
|
|
func TestLogoutLocal(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
localLogout(t, rpClient, idp)
|
|
}
|
|
|
|
func TestLogoutCallback(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
logout(t, rpClient, idp)
|
|
}
|
|
|
|
func TestLogoutCallback_WithRedirect(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
redirect := idp.RelyingPartyServer.URL + "/api/me"
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
logout(t, rpClient, idp, redirect)
|
|
}
|
|
|
|
func TestFrontChannelLogout(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport()
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
sessionCookie := login(t, rpClient, idp)
|
|
|
|
// Trigger front-channel logout
|
|
sid := func(r *http.Request) string {
|
|
r.AddCookie(sessionCookie)
|
|
|
|
data, err := idp.RelyingPartyHandler.SessionManager.Get(r)
|
|
assert.NoError(t, err)
|
|
|
|
return data.ExternalSessionID()
|
|
}
|
|
|
|
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
|
|
assert.NoError(t, err)
|
|
|
|
req := idp.GetRequest(frontchannelLogoutURL.String())
|
|
|
|
values := url.Values{}
|
|
values.Add("sid", sid(req))
|
|
values.Add("iss", idp.OpenIDConfig.Provider().Issuer())
|
|
frontchannelLogoutURL.RawQuery = values.Encode()
|
|
|
|
resp := get(t, rpClient, frontchannelLogoutURL.String())
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
}
|
|
|
|
func TestSessionRefresh(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.Session.Refresh = true
|
|
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
// get initial session info
|
|
resp := sessionInfo(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var data session.MetadataVerboseWithRefresh
|
|
err := json.Unmarshal([]byte(resp.Body), &data)
|
|
assert.NoError(t, err)
|
|
|
|
// wait until refresh cooldown has reached zero before refresh
|
|
waitForRefreshCooldownTimer(t, idp, rpClient)
|
|
|
|
resp = sessionRefresh(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var refreshedData session.MetadataVerboseWithRefresh
|
|
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
|
|
assert.NoError(t, err)
|
|
|
|
// session create and end times should be unchanged
|
|
assert.WithinDuration(t, data.Session.CreatedAt, refreshedData.Session.CreatedAt, 0)
|
|
assert.WithinDuration(t, data.Session.EndsAt, refreshedData.Session.EndsAt, 0)
|
|
|
|
// token expiration and refresh times should be later than before
|
|
assert.True(t, refreshedData.Tokens.ExpireAt.After(data.Tokens.ExpireAt))
|
|
assert.True(t, refreshedData.Tokens.RefreshedAt.After(data.Tokens.RefreshedAt))
|
|
|
|
allowedSkew := 5 * time.Second
|
|
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.Tokens.ExpireAt, allowedSkew)
|
|
assert.WithinDuration(t, time.Now(), refreshedData.Tokens.RefreshedAt, allowedSkew)
|
|
|
|
sessionEndDuration := time.Duration(refreshedData.Session.EndsInSeconds) * time.Second
|
|
// 1 second < time until session ends <= configured max session lifetime
|
|
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
|
assert.Greater(t, sessionEndDuration, time.Second)
|
|
|
|
tokenExpiryDuration := time.Duration(refreshedData.Tokens.ExpireInSeconds) * time.Second
|
|
// 1 second < time until token expires <= max duration for tokens from IDP
|
|
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
|
assert.Greater(t, tokenExpiryDuration, time.Second)
|
|
|
|
// auto refresh is not enabled
|
|
assert.Less(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds)
|
|
assert.Equal(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(-1))
|
|
|
|
assert.True(t, refreshedData.Tokens.RefreshCooldown)
|
|
// 1 second < refresh cooldown <= minimum refresh interval
|
|
assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
|
|
assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1))
|
|
|
|
assert.True(t, data.Session.Active)
|
|
assert.True(t, refreshedData.Session.Active)
|
|
|
|
assert.True(t, data.Session.TimeoutAt.IsZero())
|
|
assert.True(t, refreshedData.Session.TimeoutAt.IsZero())
|
|
|
|
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
|
assert.Equal(t, int64(-1), refreshedData.Session.TimeoutInSeconds)
|
|
}
|
|
|
|
func TestSessionRefresh_Disabled(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.Session.Refresh = false
|
|
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
resp := sessionRefresh(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
|
}
|
|
|
|
func TestSessionRefresh_WithInactivity(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.Session.Refresh = true
|
|
cfg.Session.Inactivity = true
|
|
cfg.Session.InactivityTimeout = 10 * time.Minute
|
|
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
// get initial session info
|
|
resp := sessionInfo(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var data session.MetadataVerboseWithRefresh
|
|
err := json.Unmarshal([]byte(resp.Body), &data)
|
|
assert.NoError(t, err)
|
|
|
|
// wait until refresh cooldown has reached zero before refresh
|
|
waitForRefreshCooldownTimer(t, idp, rpClient)
|
|
|
|
resp = sessionRefresh(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var refreshedData session.MetadataVerboseWithRefresh
|
|
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
|
|
assert.NoError(t, err)
|
|
|
|
maxDelta := 5 * time.Second
|
|
|
|
assert.True(t, data.Session.Active)
|
|
assert.True(t, refreshedData.Session.Active)
|
|
|
|
assert.False(t, data.Session.TimeoutAt.IsZero())
|
|
assert.False(t, refreshedData.Session.TimeoutAt.IsZero())
|
|
|
|
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
|
|
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
|
|
assert.WithinDuration(t, expectedTimeoutAt, refreshedData.Session.TimeoutAt, maxDelta)
|
|
|
|
assert.True(t, refreshedData.Session.TimeoutAt.After(data.Session.TimeoutAt))
|
|
|
|
previousTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
|
|
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(previousTimeoutDuration), maxDelta)
|
|
|
|
refreshedTimeoutDuration := time.Duration(refreshedData.Session.TimeoutInSeconds) * time.Second
|
|
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(refreshedTimeoutDuration), maxDelta)
|
|
}
|
|
|
|
func TestSessionRefresh_WithRefreshAuto(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.Session.Refresh = true
|
|
cfg.Session.RefreshAuto = true
|
|
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.ProviderHandler.TokenDuration = 5 * time.Second
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
// get initial session info
|
|
resp := sessionInfo(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var data session.MetadataVerboseWithRefresh
|
|
err := json.Unmarshal([]byte(resp.Body), &data)
|
|
assert.NoError(t, err)
|
|
|
|
// wait until refresh cooldown has reached zero before refresh
|
|
waitForRefreshCooldownTimer(t, idp, rpClient)
|
|
|
|
resp = sessionRefresh(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var refreshedData session.MetadataVerboseWithRefresh
|
|
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
|
|
assert.NoError(t, err)
|
|
|
|
// 1 second < next token refresh <= seconds until token expires
|
|
assert.LessOrEqual(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds)
|
|
assert.Greater(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(1))
|
|
}
|
|
|
|
func TestSession(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.ProviderHandler.TokenDuration = 5 * time.Minute
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
resp := sessionInfo(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var data session.MetadataVerbose
|
|
err := json.Unmarshal([]byte(resp.Body), &data)
|
|
assert.NoError(t, err)
|
|
|
|
allowedSkew := 5 * time.Second
|
|
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
|
|
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
|
|
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
|
|
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
|
|
|
|
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
|
|
// 1 second < time until session ends <= configured max session lifetime
|
|
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
|
assert.Greater(t, sessionEndDuration, time.Second)
|
|
|
|
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
|
|
// 1 second < time until token expires <= max duration for tokens from IDP
|
|
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
|
assert.Greater(t, tokenExpiryDuration, time.Second)
|
|
|
|
assert.True(t, data.Session.Active)
|
|
assert.True(t, data.Session.TimeoutAt.IsZero())
|
|
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
|
}
|
|
|
|
func TestSession_WithInactivity(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.Session.Refresh = true
|
|
cfg.Session.Inactivity = true
|
|
cfg.Session.InactivityTimeout = 10 * time.Minute
|
|
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
resp := sessionInfo(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var data session.MetadataVerbose
|
|
err := json.Unmarshal([]byte(resp.Body), &data)
|
|
assert.NoError(t, err)
|
|
|
|
maxDelta := 5 * time.Second
|
|
|
|
assert.True(t, data.Session.Active)
|
|
assert.False(t, data.Session.TimeoutAt.IsZero())
|
|
|
|
expectedTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
|
|
assert.WithinDuration(t, expectedTimeoutAt, data.Session.TimeoutAt, maxDelta)
|
|
|
|
actualTimeoutDuration := time.Duration(data.Session.TimeoutInSeconds) * time.Second
|
|
assert.WithinDuration(t, expectedTimeoutAt, time.Now().Add(actualTimeoutDuration), maxDelta)
|
|
}
|
|
|
|
func TestSession_WithRefresh(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.Session.Refresh = true
|
|
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.ProviderHandler.TokenDuration = 5 * time.Minute
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
resp := sessionInfo(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var data session.MetadataVerboseWithRefresh
|
|
err := json.Unmarshal([]byte(resp.Body), &data)
|
|
assert.NoError(t, err)
|
|
|
|
allowedSkew := 5 * time.Second
|
|
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
|
|
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
|
|
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
|
|
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
|
|
|
|
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
|
|
// 1 second < time until session ends <= configured max session lifetime
|
|
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
|
|
assert.Greater(t, sessionEndDuration, time.Second)
|
|
|
|
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
|
|
// 1 second < time until token expires <= max duration for tokens from IDP
|
|
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
|
|
assert.Greater(t, tokenExpiryDuration, time.Second)
|
|
|
|
// auto refresh is not enabled
|
|
assert.Less(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds)
|
|
assert.Equal(t, data.Tokens.NextAutoRefreshInSeconds, int64(-1))
|
|
|
|
assert.True(t, data.Tokens.RefreshCooldown)
|
|
// 1 second < refresh cooldown <= minimum refresh interval
|
|
assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
|
|
assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1))
|
|
|
|
assert.True(t, data.Session.Active)
|
|
assert.True(t, data.Session.TimeoutAt.IsZero())
|
|
assert.Equal(t, int64(-1), data.Session.TimeoutInSeconds)
|
|
}
|
|
|
|
func TestSession_WithRefreshAuto(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.Session.Refresh = true
|
|
cfg.Session.RefreshAuto = true
|
|
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
idp.ProviderHandler.TokenDuration = 5 * time.Minute
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
login(t, rpClient, idp)
|
|
|
|
resp := sessionInfo(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var data session.MetadataVerboseWithRefresh
|
|
err := json.Unmarshal([]byte(resp.Body), &data)
|
|
assert.NoError(t, err)
|
|
|
|
// 1 second < next token refresh <= seconds until token expires
|
|
assert.LessOrEqual(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds)
|
|
assert.Greater(t, data.Tokens.NextAutoRefreshInSeconds, int64(1))
|
|
}
|
|
|
|
func TestPing(t *testing.T) {
|
|
cfg := mock.Config()
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
rpClient := idp.RelyingPartyClient()
|
|
resp := get(t, rpClient, idp.RelyingPartyServer.URL+"/oauth2/ping")
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
assert.Equal(t, "pong", resp.Body)
|
|
}
|
|
|
|
func localLogin(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
|
|
// First, run /oauth2/login to set cookies
|
|
loginURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/login")
|
|
assert.NoError(t, err)
|
|
|
|
resp := get(t, rpClient, loginURL.String())
|
|
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
|
|
|
cookies := rpClient.Jar.Cookies(loginURL)
|
|
sessionCookie := getCookieFromJar(cookie.Session, cookies)
|
|
loginCookie := getCookieFromJar(cookie.Login, cookies)
|
|
loginLegacyCookie := getCookieFromJar(cookie.LoginLegacy, cookies)
|
|
|
|
assert.Nil(t, sessionCookie)
|
|
assert.NotNil(t, loginCookie)
|
|
assert.NotNil(t, loginLegacyCookie)
|
|
|
|
return resp
|
|
}
|
|
|
|
func authorize(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
|
|
resp := localLogin(t, rpClient, idp)
|
|
|
|
// Follow redirect to authorize with identity provider
|
|
resp = get(t, rpClient, resp.Location.String())
|
|
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
|
|
|
return resp
|
|
}
|
|
|
|
func callback(t *testing.T, rpClient *http.Client, authorizeResponse response) *http.Cookie {
|
|
// Get callback URL after successful auth
|
|
callbackURL := authorizeResponse.Location
|
|
|
|
// Follow redirect to callback
|
|
resp := get(t, rpClient, callbackURL.String())
|
|
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
|
|
|
cookies := rpClient.Jar.Cookies(callbackURL)
|
|
sessionCookie := getCookieFromJar(cookie.Session, cookies)
|
|
loginCookie := getCookieFromJar(cookie.Login, cookies)
|
|
loginLegacyCookie := getCookieFromJar(cookie.LoginLegacy, cookies)
|
|
|
|
assert.NotNil(t, sessionCookie)
|
|
assert.Nil(t, loginCookie)
|
|
assert.Nil(t, loginLegacyCookie)
|
|
|
|
return sessionCookie
|
|
}
|
|
|
|
func login(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) *http.Cookie {
|
|
resp := authorize(t, rpClient, idp)
|
|
return callback(t, rpClient, resp)
|
|
}
|
|
|
|
func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider, redirectAfterLogout ...string) response {
|
|
// Request self-initiated logout
|
|
logoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout")
|
|
assert.NoError(t, err)
|
|
|
|
if len(redirectAfterLogout) > 0 {
|
|
v := url.Values{}
|
|
v.Set(urlpkg.RedirectQueryParameter, redirectAfterLogout[0])
|
|
logoutURL.RawQuery = v.Encode()
|
|
}
|
|
|
|
resp := get(t, rpClient, logoutURL.String())
|
|
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
|
|
|
cookies := rpClient.Jar.Cookies(logoutURL)
|
|
sessionCookie := getCookieFromJar(cookie.Session, cookies)
|
|
|
|
assert.Nil(t, sessionCookie)
|
|
|
|
// Get endsession endpoint after local logout
|
|
endsessionURL := resp.Location
|
|
|
|
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
|
|
assert.NoError(t, err)
|
|
|
|
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout")
|
|
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
|
|
assert.NoError(t, err)
|
|
|
|
endsessionParams := endsessionURL.Query()
|
|
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
|
|
assert.Equal(t, "/endsession", endsessionURL.Path)
|
|
assert.Equal(t, expectedLogoutCallbackURL, endsessionParams.Get("post_logout_redirect_uri"))
|
|
assert.NotEmpty(t, endsessionParams.Get("id_token_hint"))
|
|
assert.NotEmpty(t, endsessionParams.Get("state"))
|
|
|
|
return resp
|
|
}
|
|
|
|
func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider, redirectAfterLogout ...string) {
|
|
// Get endsession endpoint after local logout
|
|
resp := selfInitiatedLogout(t, rpClient, idp, redirectAfterLogout...)
|
|
expectedState := resp.Location.Query().Get("state")
|
|
|
|
// Follow redirect to endsession endpoint at identity provider
|
|
resp = get(t, rpClient, resp.Location.String())
|
|
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
|
|
|
logoutCallbackURI := resp.Location
|
|
|
|
// Assert state for callback equals state sent in initial logout request
|
|
actualState := logoutCallbackURI.Query().Get("state")
|
|
assert.NotEmpty(t, actualState)
|
|
assert.Equal(t, expectedState, actualState)
|
|
|
|
// Assert post-logout redirect URI after successful logout at identity provider
|
|
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout")
|
|
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
|
|
assert.NoError(t, err)
|
|
|
|
assert.Contains(t, logoutCallbackURI.String(), expectedLogoutCallbackURL)
|
|
assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path)
|
|
|
|
// Follow redirect back to logout callback
|
|
resp = get(t, rpClient, logoutCallbackURI.String())
|
|
|
|
// Get post-logout redirect URI after redirect back to logout callback
|
|
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
|
|
|
expectedRedirect := "https://google.com"
|
|
if len(redirectAfterLogout) > 0 {
|
|
expectedRedirect = redirectAfterLogout[0]
|
|
}
|
|
assert.Equal(t, expectedRedirect, resp.Location.String())
|
|
|
|
cookies := rpClient.Jar.Cookies(logoutCallbackURI)
|
|
sessionCookie := getCookieFromJar(cookie.Session, cookies)
|
|
|
|
assert.Nil(t, sessionCookie)
|
|
}
|
|
|
|
func localLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
|
|
logoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/local")
|
|
assert.NoError(t, err)
|
|
|
|
resp := get(t, rpClient, logoutURL.String())
|
|
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
|
|
|
cookies := rpClient.Jar.Cookies(logoutURL)
|
|
sessionCookie := getCookieFromJar(cookie.Session, cookies)
|
|
|
|
assert.Nil(t, sessionCookie)
|
|
|
|
return resp
|
|
}
|
|
|
|
func sessionInfo(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
|
|
sessionInfoURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session")
|
|
assert.NoError(t, err)
|
|
|
|
return get(t, rpClient, sessionInfoURL.String())
|
|
}
|
|
|
|
func sessionRefresh(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
|
|
sessionRefreshURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session/refresh")
|
|
assert.NoError(t, err)
|
|
|
|
return post(t, rpClient, sessionRefreshURL.String())
|
|
}
|
|
|
|
func waitForRefreshCooldownTimer(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) {
|
|
timeout := time.After(5 * time.Second)
|
|
ticker := time.Tick(500 * time.Millisecond)
|
|
for {
|
|
select {
|
|
case <-timeout:
|
|
assert.Fail(t, "refresh cooldown timer exceeded timeout")
|
|
case <-ticker:
|
|
resp := sessionInfo(t, idp, rpClient)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
|
|
var temp session.MetadataVerboseWithRefresh
|
|
err := json.Unmarshal([]byte(resp.Body), &temp)
|
|
assert.NoError(t, err)
|
|
|
|
if !temp.Tokens.RefreshCooldown {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// response is the condensed view of an HTTP response that the get and post
// helpers hand back to tests.
type response struct {
	// Body is the response body read in full.
	Body string

	// Location is the parsed Location header of the response; it is nil when
	// the response has none.
	Location *url.URL

	// StatusCode is the HTTP status code of the response.
	StatusCode int
}
|
|
|
|
// header is a single HTTP request header that can be passed to get.
type header struct {
	// key is the header name.
	key string
	// value is the header value.
	value string
}
|
|
|
|
func get(t *testing.T, client *http.Client, url string, headers ...header) response {
|
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
|
assert.NoError(t, err)
|
|
|
|
for _, h := range headers {
|
|
req.Header.Set(h.key, h.value)
|
|
}
|
|
|
|
resp, err := client.Do(req)
|
|
assert.NoError(t, err)
|
|
|
|
location, err := resp.Location()
|
|
if !errors.Is(http.ErrNoLocation, err) {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
return response{
|
|
Body: body(t, resp),
|
|
Location: location,
|
|
StatusCode: resp.StatusCode,
|
|
}
|
|
}
|
|
|
|
func post(t *testing.T, client *http.Client, url string) response {
|
|
req, err := http.NewRequest(http.MethodPost, url, nil)
|
|
assert.NoError(t, err)
|
|
|
|
resp, err := client.Do(req)
|
|
assert.NoError(t, err)
|
|
|
|
location, err := resp.Location()
|
|
if !errors.Is(http.ErrNoLocation, err) {
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
return response{
|
|
Body: body(t, resp),
|
|
Location: location,
|
|
StatusCode: resp.StatusCode,
|
|
}
|
|
}
|
|
|
|
func body(t *testing.T, resp *http.Response) string {
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
assert.NoError(t, err)
|
|
|
|
return string(body)
|
|
}
|
|
|
|
type upstream struct {
|
|
Server *httptest.Server
|
|
URL *url.URL
|
|
idp *mock.IdentityProvider
|
|
reverseProxyURL *url.URL
|
|
requestCallback func(r *http.Request)
|
|
}
|
|
|
|
func (u *upstream) SetIdentityProvider(idp *mock.IdentityProvider) {
|
|
u.idp = idp
|
|
u.setReverseProxyUrl(idp.RelyingPartyServer.URL)
|
|
}
|
|
|
|
func (u *upstream) setReverseProxyUrl(raw string) {
|
|
parsed, err := url.Parse(raw)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
u.reverseProxyURL = parsed
|
|
}
|
|
|
|
func (u *upstream) hasValidToken(r *http.Request) bool {
|
|
authHeader := r.Header.Get("Authorization")
|
|
token := strings.TrimPrefix(authHeader, "Bearer ")
|
|
if len(token) <= 0 {
|
|
return false
|
|
}
|
|
|
|
jwks, err := u.idp.ProviderHandler.Provider.GetPublicJwkSet(r.Context())
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
opts := []jwt.ParseOption{
|
|
jwt.WithValidate(true),
|
|
jwt.WithKeySet(*jwks),
|
|
jwt.WithIssuer(u.idp.OpenIDConfig.Provider().Issuer()),
|
|
jwt.WithAudience(u.idp.OpenIDConfig.Client().ClientID()),
|
|
}
|
|
|
|
_, err = jwt.ParseString(token, opts...)
|
|
return err == nil
|
|
}
|
|
|
|
func newUpstream(t *testing.T) *upstream {
|
|
u := new(upstream)
|
|
u.requestCallback = func(r *http.Request) {}
|
|
|
|
upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
u.requestCallback(r)
|
|
|
|
// Host should match the original authority from the ingress used to reach Wonderwall
|
|
assert.Equal(t, u.reverseProxyURL.Host, r.Host)
|
|
assert.NotEqual(t, u.URL.Host, r.Host)
|
|
|
|
if u.hasValidToken(r) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("ok"))
|
|
} else {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = w.Write([]byte("not ok"))
|
|
}
|
|
})
|
|
server := httptest.NewServer(upstreamHandler)
|
|
|
|
upstreamURL, err := url.Parse(server.URL)
|
|
assert.NoError(t, err)
|
|
|
|
u.Server = server
|
|
u.URL = upstreamURL
|
|
return u
|
|
}
|
|
|
|
// getCookieFromJar returns the first cookie in cookies whose name equals name,
// or nil when there is no such cookie.
func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
	for _, candidate := range cookies {
		if candidate.Name != name {
			continue
		}
		return candidate
	}

	return nil
}
|