Files
wonderwall/pkg/handler/handler_test.go
2023-04-29 08:42:29 +02:00

742 lines
23 KiB
Go

package handler_test
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/session"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
// TestLogin checks that /oauth2/login redirects to the identity provider's
// authorize endpoint with the expected query parameters, and that following
// the redirect yields a callback location that echoes the state and carries
// an authorization code.
func TestLogin(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	defer idp.Close()

	client := idp.RelyingPartyClient()
	loginResp := localLogin(t, client, idp)
	authorizeURL := loginResp.Location

	loginReq := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login")
	wantRedirectURI, err := urlpkg.LoginCallback(loginReq)
	assert.NoError(t, err)

	// The redirect must target the provider's authorize endpoint.
	assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", authorizeURL.Scheme, authorizeURL.Host))
	assert.Equal(t, "/authorize", authorizeURL.Path)

	// Parameters derived from the client configuration.
	params := authorizeURL.Query()
	assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), params.Get("acr_values"))
	assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), params.Get("ui_locales"))
	assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), params.Get("client_id"))
	assert.Equal(t, wantRedirectURI, params.Get("redirect_uri"))
	assert.Equal(t, "S256", params.Get("code_challenge_method"))
	assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(params.Get("scope"), " "))

	// Per-request values only need to be present.
	assert.NotEmpty(t, params.Get("state"))
	assert.NotEmpty(t, params.Get("nonce"))
	assert.NotEmpty(t, params.Get("code_challenge"))

	// Follow the redirect to the provider and inspect the callback location.
	authorizeResp := get(t, client, authorizeURL.String())
	assert.Equal(t, http.StatusFound, authorizeResp.StatusCode)

	callbackParams := authorizeResp.Location.Query()
	assert.Equal(t, params.Get("state"), callbackParams.Get("state"))
	assert.NotEmpty(t, callbackParams.Get("code"))
}
// TestCallback runs the complete login flow (login, authorize, callback);
// the assertions live in the login helper chain.
func TestCallback(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)
}
// TestCallback_SessionStateRequired enables check-session iframe support on
// the test provider and checks that the authorize redirect carries a
// non-empty session_state parameter before running the callback step.
func TestCallback_SessionStateRequired(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession")
	defer idp.Close()

	client := idp.RelyingPartyClient()
	authorizeResp := authorize(t, client, idp)

	// The callback URL returned after successful auth must include session_state.
	callbackParams := authorizeResp.Location.Query()
	assert.NotEmpty(t, callbackParams.Get("session_state"))

	callback(t, client, authorizeResp)
}
// TestLogout logs in, requests a self-initiated logout and checks that the
// resulting redirect targets the provider's end-session endpoint with the
// expected post-logout redirect URI and an ID token hint.
func TestLogout(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)

	// The local logout responds with a redirect to the end-session endpoint.
	logoutResp := selfInitiatedLogout(t, client, idp)
	endSession := logoutResp.Location

	providerURL, err := url.Parse(idp.ProviderServer.URL)
	assert.NoError(t, err)

	callbackReq := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback")
	wantCallback, err := urlpkg.LogoutCallback(callbackReq)
	assert.NoError(t, err)

	query := endSession.Query()
	assert.Equal(t, providerURL.Host, endSession.Host)
	assert.Equal(t, "/endsession", endSession.Path)
	assert.Equal(t, []string{wantCallback}, query["post_logout_redirect_uri"])
	assert.NotEmpty(t, query["id_token_hint"])
}
// TestLogoutLocal logs in and then exercises the local logout endpoint; the
// assertions live in the localLogout helper.
func TestLogoutLocal(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)
	localLogout(t, client, idp)
}
// TestLogoutCallback logs in and then walks the full logout redirect chain,
// including the logout callback; the assertions live in the logout helper.
func TestLogoutCallback(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)
	logout(t, client, idp)
}
// TestFrontChannelLogout logs in and then calls the front-channel logout
// endpoint with the session's external ID ("sid") and the provider's issuer
// ("iss"), expecting a 200 response.
func TestFrontChannelLogout(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport()
	defer idp.Close()

	client := idp.RelyingPartyClient()
	sessionCookie := login(t, client, idp)

	// externalSessionID attaches the session cookie to r and looks up the
	// external session ID through the relying party's session manager.
	externalSessionID := func(r *http.Request) string {
		r.AddCookie(sessionCookie)
		sess, err := idp.RelyingPartyHandler.SessionManager.Get(r)
		assert.NoError(t, err)
		return sess.ExternalSessionID()
	}

	// Trigger front-channel logout
	endpoint, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
	assert.NoError(t, err)

	lookupReq := idp.GetRequest(endpoint.String())

	query := url.Values{}
	query.Add("sid", externalSessionID(lookupReq))
	query.Add("iss", idp.OpenIDConfig.Provider().Issuer())
	endpoint.RawQuery = query.Encode()

	logoutResp := get(t, client, endpoint.String())
	assert.Equal(t, http.StatusOK, logoutResp.StatusCode)
}
// TestSessionRefresh logs in with refresh enabled and short-lived tokens,
// waits for the refresh cooldown to clear, refreshes the session and compares
// the metadata returned before and after the refresh.
func TestSessionRefresh(t *testing.T) {
	cfg := mock.Config()
	cfg.Session.Refresh = true

	idp := mock.NewIdentityProvider(cfg)
	idp.ProviderHandler.TokenDuration = 5 * time.Second
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)

	// Snapshot the session metadata before refreshing.
	resp := sessionInfo(t, idp, client)
	assert.Equal(t, http.StatusOK, resp.StatusCode)

	var initial session.MetadataVerboseWithRefresh
	assert.NoError(t, json.Unmarshal([]byte(resp.Body), &initial))

	// Poll until the refresh cooldown is no longer active, then refresh.
	waitForRefreshCooldownTimer(t, idp, client)

	resp = sessionRefresh(t, idp, client)
	assert.Equal(t, http.StatusOK, resp.StatusCode)

	var refreshed session.MetadataVerboseWithRefresh
	assert.NoError(t, json.Unmarshal([]byte(resp.Body), &refreshed))

	// Session creation and end times must be unchanged by the refresh.
	assert.WithinDuration(t, initial.Session.CreatedAt, refreshed.Session.CreatedAt, 0)
	assert.WithinDuration(t, initial.Session.EndsAt, refreshed.Session.EndsAt, 0)

	// Token expiry and refresh timestamps must have moved forward.
	assert.True(t, refreshed.Tokens.ExpireAt.After(initial.Tokens.ExpireAt))
	assert.True(t, refreshed.Tokens.RefreshedAt.After(initial.Tokens.RefreshedAt))

	const allowedSkew = 5 * time.Second
	assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshed.Tokens.ExpireAt, allowedSkew)
	assert.WithinDuration(t, time.Now(), refreshed.Tokens.RefreshedAt, allowedSkew)

	// 1 second < time until session ends <= configured max session lifetime
	untilSessionEnd := time.Duration(refreshed.Session.EndsInSeconds) * time.Second
	assert.LessOrEqual(t, untilSessionEnd, cfg.Session.MaxLifetime)
	assert.Greater(t, untilSessionEnd, time.Second)

	// 1 second < time until token expires <= max duration for tokens from IDP
	untilTokenExpiry := time.Duration(refreshed.Tokens.ExpireInSeconds) * time.Second
	assert.LessOrEqual(t, untilTokenExpiry, idp.ProviderHandler.TokenDuration)
	assert.Greater(t, untilTokenExpiry, time.Second)

	// 1 second < next token refresh <= seconds until token expires
	assert.LessOrEqual(t, refreshed.Tokens.NextAutoRefreshInSeconds, refreshed.Tokens.ExpireInSeconds)
	assert.Greater(t, refreshed.Tokens.NextAutoRefreshInSeconds, int64(1))

	// 1 second < refresh cooldown <= minimum refresh interval
	assert.True(t, refreshed.Tokens.RefreshCooldown)
	assert.LessOrEqual(t, refreshed.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
	assert.Greater(t, refreshed.Tokens.RefreshCooldownSeconds, int64(1))

	// Inactivity is not configured: sessions stay active and report no timeout.
	assert.True(t, initial.Session.Active)
	assert.True(t, refreshed.Session.Active)

	assert.True(t, initial.Session.TimeoutAt.IsZero())
	assert.True(t, refreshed.Session.TimeoutAt.IsZero())

	assert.Equal(t, int64(-1), initial.Session.TimeoutInSeconds)
	assert.Equal(t, int64(-1), refreshed.Session.TimeoutInSeconds)
}
// TestSessionRefresh_Disabled checks that the refresh endpoint answers with
// 404 Not Found when session refresh is turned off in the configuration.
func TestSessionRefresh_Disabled(t *testing.T) {
	cfg := mock.Config()
	cfg.Session.Refresh = false

	idp := mock.NewIdentityProvider(cfg)
	idp.ProviderHandler.TokenDuration = 5 * time.Second
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)

	refreshResp := sessionRefresh(t, idp, client)
	assert.Equal(t, http.StatusNotFound, refreshResp.StatusCode)
}
// TestSessionRefresh_WithInactivity enables refresh and a 10-minute
// inactivity timeout, refreshes the session once the cooldown has cleared and
// checks that the reported inactivity timeout is present and moves forward.
func TestSessionRefresh_WithInactivity(t *testing.T) {
	cfg := mock.Config()
	cfg.Session.Refresh = true
	cfg.Session.Inactivity = true
	cfg.Session.InactivityTimeout = 10 * time.Minute

	idp := mock.NewIdentityProvider(cfg)
	idp.ProviderHandler.TokenDuration = 5 * time.Second
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)

	// Snapshot the session metadata before refreshing.
	resp := sessionInfo(t, idp, client)
	assert.Equal(t, http.StatusOK, resp.StatusCode)

	var initial session.MetadataVerboseWithRefresh
	assert.NoError(t, json.Unmarshal([]byte(resp.Body), &initial))

	// Poll until the refresh cooldown is no longer active, then refresh.
	waitForRefreshCooldownTimer(t, idp, client)

	resp = sessionRefresh(t, idp, client)
	assert.Equal(t, http.StatusOK, resp.StatusCode)

	var refreshed session.MetadataVerboseWithRefresh
	assert.NoError(t, json.Unmarshal([]byte(resp.Body), &refreshed))

	const maxDelta = 5 * time.Second

	assert.True(t, initial.Session.Active)
	assert.True(t, refreshed.Session.Active)

	// With inactivity enabled a timeout instant must be reported.
	assert.False(t, initial.Session.TimeoutAt.IsZero())
	assert.False(t, refreshed.Session.TimeoutAt.IsZero())

	wantTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
	assert.WithinDuration(t, wantTimeoutAt, initial.Session.TimeoutAt, maxDelta)
	assert.WithinDuration(t, wantTimeoutAt, refreshed.Session.TimeoutAt, maxDelta)

	// The refresh must push the timeout later than it was before.
	assert.True(t, refreshed.Session.TimeoutAt.After(initial.Session.TimeoutAt))

	// The relative timeout values must agree with the absolute ones.
	initialTimeoutIn := time.Duration(initial.Session.TimeoutInSeconds) * time.Second
	assert.WithinDuration(t, wantTimeoutAt, time.Now().Add(initialTimeoutIn), maxDelta)

	refreshedTimeoutIn := time.Duration(refreshed.Session.TimeoutInSeconds) * time.Second
	assert.WithinDuration(t, wantTimeoutAt, time.Now().Add(refreshedTimeoutIn), maxDelta)
}
// TestSession logs in and checks the metadata returned by the session info
// endpoint: creation and end times, token expiry and refresh times, and the
// absence of an inactivity timeout.
func TestSession(t *testing.T) {
	cfg := mock.Config()
	cfg.Session.Refresh = true

	idp := mock.NewIdentityProvider(cfg)
	idp.ProviderHandler.TokenDuration = 5 * time.Minute
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)

	infoResp := sessionInfo(t, idp, client)
	assert.Equal(t, http.StatusOK, infoResp.StatusCode)

	var meta session.MetadataVerbose
	assert.NoError(t, json.Unmarshal([]byte(infoResp.Body), &meta))

	// Absolute timestamps are compared against the local clock with some slack.
	const allowedSkew = 5 * time.Second
	assert.WithinDuration(t, time.Now(), meta.Session.CreatedAt, allowedSkew)
	assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), meta.Session.EndsAt, allowedSkew)
	assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), meta.Tokens.ExpireAt, allowedSkew)
	assert.WithinDuration(t, time.Now(), meta.Tokens.RefreshedAt, allowedSkew)

	// 1 second < time until session ends <= configured max session lifetime
	untilSessionEnd := time.Duration(meta.Session.EndsInSeconds) * time.Second
	assert.LessOrEqual(t, untilSessionEnd, cfg.Session.MaxLifetime)
	assert.Greater(t, untilSessionEnd, time.Second)

	// 1 second < time until token expires <= max duration for tokens from IDP
	untilTokenExpiry := time.Duration(meta.Tokens.ExpireInSeconds) * time.Second
	assert.LessOrEqual(t, untilTokenExpiry, idp.ProviderHandler.TokenDuration)
	assert.Greater(t, untilTokenExpiry, time.Second)

	// Inactivity is not configured: the session is active and reports no timeout.
	assert.True(t, meta.Session.Active)
	assert.True(t, meta.Session.TimeoutAt.IsZero())
	assert.Equal(t, int64(-1), meta.Session.TimeoutInSeconds)
}
// TestSession_WithInactivity enables a 10-minute inactivity timeout and
// checks that the session info endpoint reports a matching timeout, both as
// an absolute time and as seconds remaining.
func TestSession_WithInactivity(t *testing.T) {
	cfg := mock.Config()
	cfg.Session.Refresh = true
	cfg.Session.Inactivity = true
	cfg.Session.InactivityTimeout = 10 * time.Minute

	idp := mock.NewIdentityProvider(cfg)
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)

	infoResp := sessionInfo(t, idp, client)
	assert.Equal(t, http.StatusOK, infoResp.StatusCode)

	var meta session.MetadataVerbose
	assert.NoError(t, json.Unmarshal([]byte(infoResp.Body), &meta))

	const maxDelta = 5 * time.Second

	assert.True(t, meta.Session.Active)
	assert.False(t, meta.Session.TimeoutAt.IsZero())

	wantTimeoutAt := time.Now().Add(cfg.Session.InactivityTimeout)
	assert.WithinDuration(t, wantTimeoutAt, meta.Session.TimeoutAt, maxDelta)

	// The relative timeout must agree with the absolute one.
	timeoutIn := time.Duration(meta.Session.TimeoutInSeconds) * time.Second
	assert.WithinDuration(t, wantTimeoutAt, time.Now().Add(timeoutIn), maxDelta)
}
// TestSession_WithRefresh logs in with refresh enabled and checks the
// refresh-specific fields of the session info response (next auto refresh and
// refresh cooldown) alongside the regular session and token metadata.
func TestSession_WithRefresh(t *testing.T) {
	cfg := mock.Config()
	cfg.Session.Refresh = true

	idp := mock.NewIdentityProvider(cfg)
	idp.ProviderHandler.TokenDuration = 5 * time.Minute
	defer idp.Close()

	client := idp.RelyingPartyClient()
	login(t, client, idp)

	infoResp := sessionInfo(t, idp, client)
	assert.Equal(t, http.StatusOK, infoResp.StatusCode)

	var meta session.MetadataVerboseWithRefresh
	assert.NoError(t, json.Unmarshal([]byte(infoResp.Body), &meta))

	// Absolute timestamps are compared against the local clock with some slack.
	const allowedSkew = 5 * time.Second
	assert.WithinDuration(t, time.Now(), meta.Session.CreatedAt, allowedSkew)
	assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), meta.Session.EndsAt, allowedSkew)
	assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), meta.Tokens.ExpireAt, allowedSkew)
	assert.WithinDuration(t, time.Now(), meta.Tokens.RefreshedAt, allowedSkew)

	// 1 second < time until session ends <= configured max session lifetime
	untilSessionEnd := time.Duration(meta.Session.EndsInSeconds) * time.Second
	assert.LessOrEqual(t, untilSessionEnd, cfg.Session.MaxLifetime)
	assert.Greater(t, untilSessionEnd, time.Second)

	// 1 second < time until token expires <= max duration for tokens from IDP
	untilTokenExpiry := time.Duration(meta.Tokens.ExpireInSeconds) * time.Second
	assert.LessOrEqual(t, untilTokenExpiry, idp.ProviderHandler.TokenDuration)
	assert.Greater(t, untilTokenExpiry, time.Second)

	// 1 second < next token refresh <= seconds until token expires
	assert.LessOrEqual(t, meta.Tokens.NextAutoRefreshInSeconds, meta.Tokens.ExpireInSeconds)
	assert.Greater(t, meta.Tokens.NextAutoRefreshInSeconds, int64(1))

	// 1 second < refresh cooldown <= minimum refresh interval
	assert.True(t, meta.Tokens.RefreshCooldown)
	assert.LessOrEqual(t, meta.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
	assert.Greater(t, meta.Tokens.RefreshCooldownSeconds, int64(1))

	// Inactivity is not configured: the session is active and reports no timeout.
	assert.True(t, meta.Session.Active)
	assert.True(t, meta.Session.TimeoutAt.IsZero())
	assert.Equal(t, int64(-1), meta.Session.TimeoutInSeconds)
}
// TestPing checks that /oauth2/ping answers 200 with the body "pong".
func TestPing(t *testing.T) {
	idp := mock.NewIdentityProvider(mock.Config())
	defer idp.Close()

	client := idp.RelyingPartyClient()
	pingResp := get(t, client, idp.RelyingPartyServer.URL+"/oauth2/ping")

	assert.Equal(t, http.StatusOK, pingResp.StatusCode)
	assert.Equal(t, "pong", pingResp.Body)
}
// localLogin requests /oauth2/login on the relying party and returns the
// response. It asserts a 302 status and that the client's cookie jar now
// holds the login cookies (current and legacy) but no session cookie.
func localLogin(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
	endpoint, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/login")
	assert.NoError(t, err)

	// First, run /oauth2/login to set cookies
	resp := get(t, rpClient, endpoint.String())
	assert.Equal(t, http.StatusFound, resp.StatusCode)

	jar := rpClient.Jar.Cookies(endpoint)
	assert.Nil(t, getCookieFromJar(cookie.Session, jar))
	assert.NotNil(t, getCookieFromJar(cookie.Login, jar))
	assert.NotNil(t, getCookieFromJar(cookie.LoginLegacy, jar))

	return resp
}
// authorize performs the local login and then follows its redirect to the
// identity provider. It asserts a 302 status and returns that response.
func authorize(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
	loginResp := localLogin(t, rpClient, idp)

	// Follow redirect to authorize with identity provider
	authorizeResp := get(t, rpClient, loginResp.Location.String())
	assert.Equal(t, http.StatusFound, authorizeResp.StatusCode)

	return authorizeResp
}
// callback follows the Location of authorizeResponse (the callback URL) and
// returns the session cookie found in the client's jar afterwards. It asserts
// a 302 status, that the session cookie is set, and that both login cookies
// are gone.
func callback(t *testing.T, rpClient *http.Client, authorizeResponse response) *http.Cookie {
	// Get callback URL after successful auth
	target := authorizeResponse.Location

	// Follow redirect to callback
	callbackResp := get(t, rpClient, target.String())
	assert.Equal(t, http.StatusFound, callbackResp.StatusCode)

	jar := rpClient.Jar.Cookies(target)
	sessionCookie := getCookieFromJar(cookie.Session, jar)

	assert.NotNil(t, sessionCookie)
	assert.Nil(t, getCookieFromJar(cookie.Login, jar))
	assert.Nil(t, getCookieFromJar(cookie.LoginLegacy, jar))

	return sessionCookie
}
// login runs the authorize step followed by the callback step and returns
// the session cookie produced by the callback.
func login(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) *http.Cookie {
	authorizeResp := authorize(t, rpClient, idp)
	sessionCookie := callback(t, rpClient, authorizeResp)
	return sessionCookie
}
// selfInitiatedLogout requests /oauth2/logout on the relying party and
// returns the response. It asserts a 302 status and that the session cookie
// is no longer present in the client's jar.
func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
	endpoint, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout")
	assert.NoError(t, err)

	// Request self-initiated logout
	resp := get(t, rpClient, endpoint.String())
	assert.Equal(t, http.StatusFound, resp.StatusCode)

	jar := rpClient.Jar.Cookies(endpoint)
	assert.Nil(t, getCookieFromJar(cookie.Session, jar))

	return resp
}
// logout walks the whole logout redirect chain: self-initiated logout at the
// relying party, the end-session endpoint at the identity provider, and the
// relying party's logout callback. It asserts each hop's status and target,
// and that no session cookie remains afterwards.
func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
	// Get endsession endpoint after local logout
	localResp := selfInitiatedLogout(t, rpClient, idp)

	// Follow redirect to endsession endpoint at identity provider
	endSessionResp := get(t, rpClient, localResp.Location.String())
	assert.Equal(t, http.StatusFound, endSessionResp.StatusCode)

	// Get post-logout redirect URI after successful logout at identity provider
	callbackURI := endSessionResp.Location
	callbackReq := idp.GetRequest(endSessionResp.Location.String())
	wantCallback, err := urlpkg.LogoutCallback(callbackReq)
	assert.NoError(t, err)

	assert.Contains(t, callbackURI.String(), wantCallback)
	assert.Equal(t, "/oauth2/logout/callback", callbackURI.Path)

	// Follow redirect back to logout callback
	finalResp := get(t, rpClient, callbackURI.String())
	assert.Equal(t, http.StatusFound, finalResp.StatusCode)

	// Get post-logout redirect URI after redirect back to logout callback
	assert.Equal(t, "https://google.com", finalResp.Location.String())

	jar := rpClient.Jar.Cookies(callbackURI)
	assert.Nil(t, getCookieFromJar(cookie.Session, jar))
}
// localLogout requests /oauth2/logout/local on the relying party and returns
// the response. It asserts a 204 status and that the session cookie is no
// longer present in the client's jar.
func localLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
	endpoint, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/local")
	assert.NoError(t, err)

	resp := get(t, rpClient, endpoint.String())
	assert.Equal(t, http.StatusNoContent, resp.StatusCode)

	jar := rpClient.Jar.Cookies(endpoint)
	assert.Nil(t, getCookieFromJar(cookie.Session, jar))

	return resp
}
// sessionInfo issues a GET against /oauth2/session on the relying party and
// returns the response without asserting on it.
func sessionInfo(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
	endpoint, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session")
	assert.NoError(t, err)

	resp := get(t, rpClient, endpoint.String())
	return resp
}
// sessionRefresh issues a POST against /oauth2/session/refresh on the relying
// party and returns the response without asserting on it.
func sessionRefresh(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
	endpoint, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session/refresh")
	assert.NoError(t, err)

	resp := post(t, rpClient, endpoint.String())
	return resp
}
// waitForRefreshCooldownTimer polls the session info endpoint every 500ms
// until the response reports that the refresh cooldown is no longer active.
//
// If the cooldown is still active after 5 seconds the test is marked as
// failed and the function returns, so the caller is never blocked past the
// deadline.
func waitForRefreshCooldownTimer(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) {
	timeout := time.After(5 * time.Second)

	// Use an explicit ticker so it can be stopped on return; time.Tick offers
	// no way to release the underlying ticker.
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	for {
		select {
		case <-timeout:
			assert.Fail(t, "refresh cooldown timer exceeded timeout")
			// Stop polling: without this return the loop would keep running
			// after the deadline had already been reported.
			return
		case <-ticker.C:
			resp := sessionInfo(t, idp, rpClient)
			assert.Equal(t, http.StatusOK, resp.StatusCode)

			var temp session.MetadataVerboseWithRefresh
			err := json.Unmarshal([]byte(resp.Body), &temp)
			assert.NoError(t, err)

			if !temp.Tokens.RefreshCooldown {
				return
			}
		}
	}
}
// response is the reduced view of an *http.Response that the request helpers
// in this file hand back to tests.
type response struct {
	// Body is the complete response body read as a string.
	Body string
	// Location is the parsed Location header; it is nil when the response
	// has no such header.
	Location *url.URL
	// StatusCode is the HTTP status code of the response.
	StatusCode int
}
// get issues a GET request for url using client and returns the body,
// Location header and status code of the response.
//
// The test is stopped immediately if the request itself fails, because no
// response exists to inspect in that case. A missing Location header is not
// an error; any other failure to resolve it is reported.
func get(t *testing.T, client *http.Client, url string) response {
	resp, err := client.Get(url)
	if !assert.NoError(t, err) {
		// resp must not be dereferenced after a failed request.
		t.FailNow()
	}

	location, err := resp.Location()
	if !errors.Is(err, http.ErrNoLocation) {
		assert.NoError(t, err)
	}

	return response{
		Body:       body(t, resp),
		Location:   location,
		StatusCode: resp.StatusCode,
	}
}
// getWithHeaders issues a GET request for url with the given headers set and
// returns the body, Location header and status code of the response.
//
// The test is stopped immediately if the request cannot be built or sent,
// because there is nothing further to inspect in that case. A missing
// Location header is not an error; any other failure to resolve it is
// reported.
func getWithHeaders(t *testing.T, client *http.Client, url string, headers map[string]string) response {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if !assert.NoError(t, err) {
		// req is nil when construction fails.
		t.FailNow()
	}

	for key, value := range headers {
		req.Header.Set(key, value)
	}

	resp, err := client.Do(req)
	if !assert.NoError(t, err) {
		// resp must not be dereferenced after a failed request.
		t.FailNow()
	}

	location, err := resp.Location()
	if !errors.Is(err, http.ErrNoLocation) {
		assert.NoError(t, err)
	}

	return response{
		Body:       body(t, resp),
		Location:   location,
		StatusCode: resp.StatusCode,
	}
}
// post issues a POST request with an empty body to url and returns the body,
// Location header and status code of the response.
//
// The test is stopped immediately if the request cannot be built or sent,
// because there is nothing further to inspect in that case. A missing
// Location header is not an error; any other failure to resolve it is
// reported.
func post(t *testing.T, client *http.Client, url string) response {
	req, err := http.NewRequest(http.MethodPost, url, nil)
	if !assert.NoError(t, err) {
		// req is nil when construction fails.
		t.FailNow()
	}

	resp, err := client.Do(req)
	if !assert.NoError(t, err) {
		// resp must not be dereferenced after a failed request.
		t.FailNow()
	}

	location, err := resp.Location()
	if !errors.Is(err, http.ErrNoLocation) {
		assert.NoError(t, err)
	}

	return response{
		Body:       body(t, resp),
		Location:   location,
		StatusCode: resp.StatusCode,
	}
}
// body reads resp.Body to the end, closes it and returns the content as a
// string. A read error is reported on t.
func body(t *testing.T, resp *http.Response) string {
	defer resp.Body.Close()

	raw, err := io.ReadAll(resp.Body)
	assert.NoError(t, err)

	content := string(raw)
	return content
}
// upstream is a test HTTP server that answers 200 "ok" to requests carrying
// a bearer token that validates against the identity provider, and 401
// "not ok" otherwise.
type upstream struct {
	// Server is the httptest server created by newUpstream.
	Server *httptest.Server
	// URL is the parsed form of Server.URL.
	URL *url.URL

	// idp supplies the key set, issuer and client ID used by hasValidToken.
	idp *mock.IdentityProvider
	// reverseProxyURL is the URL whose host incoming requests are expected to
	// carry in their Host field.
	reverseProxyURL *url.URL
	// requestCallback is invoked with every incoming request before the
	// handler's own checks run.
	requestCallback func(r *http.Request)
}
// SetIdentityProvider stores idp on the upstream and sets the reverse proxy
// URL to the address of idp's relying party server.
func (u *upstream) SetIdentityProvider(idp *mock.IdentityProvider) {
	u.idp = idp

	relyingPartyURL := idp.RelyingPartyServer.URL
	u.setReverseProxyUrl(relyingPartyURL)
}
// setReverseProxyUrl parses raw and stores the result as the reverse proxy
// URL. It panics if raw is not a valid URL.
func (u *upstream) setReverseProxyUrl(raw string) {
	target, parseErr := url.Parse(raw)
	if parseErr != nil {
		panic(parseErr)
	}

	u.reverseProxyURL = target
}
// hasValidToken reports whether r carries an Authorization header whose
// token parses and validates as a JWT against the identity provider's public
// key set, issuer and client ID (as audience).
//
// It returns false when the header is empty after stripping the "Bearer "
// prefix, and panics if the public key set cannot be fetched.
func (u *upstream) hasValidToken(r *http.Request) bool {
	raw := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
	if raw == "" {
		return false
	}

	keySet, err := u.idp.ProviderHandler.Provider.GetPublicJwkSet(r.Context())
	if err != nil {
		panic(err)
	}

	_, err = jwt.ParseString(
		raw,
		jwt.WithValidate(true),
		jwt.WithKeySet(*keySet),
		jwt.WithIssuer(u.idp.OpenIDConfig.Provider().Issuer()),
		jwt.WithAudience(u.idp.OpenIDConfig.Client().ClientID()),
	)
	return err == nil
}
// newUpstream starts an httptest server and returns an upstream wrapping it.
//
// For each request the handler runs the upstream's request callback, asserts
// on the request's Host, and then answers 200 "ok" if the request carries a
// valid token or 401 "not ok" otherwise. The request callback defaults to a
// no-op.
func newUpstream(t *testing.T) *upstream {
	u := &upstream{
		requestCallback: func(r *http.Request) {},
	}

	handle := func(w http.ResponseWriter, r *http.Request) {
		u.requestCallback(r)

		// Host should match the original authority from the ingress used to reach Wonderwall
		assert.Equal(t, u.reverseProxyURL.Host, r.Host)
		assert.NotEqual(t, u.URL.Host, r.Host)

		status, payload := http.StatusUnauthorized, "not ok"
		if u.hasValidToken(r) {
			status, payload = http.StatusOK, "ok"
		}

		w.WriteHeader(status)
		_, _ = w.Write([]byte(payload))
	}

	srv := httptest.NewServer(http.HandlerFunc(handle))

	parsed, err := url.Parse(srv.URL)
	assert.NoError(t, err)

	u.Server = srv
	u.URL = parsed
	return u
}
// getCookieFromJar returns the first cookie in cookies whose Name equals
// name, or nil when there is no such cookie.
func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
	for i := 0; i < len(cookies); i++ {
		if candidate := cookies[i]; candidate.Name == name {
			return candidate
		}
	}
	return nil
}