Files
wonderwall/pkg/handler/handler_test.go
Trong Huu Nguyen 80738f2a4b fix(handler/url): use base64 encoding for redirects to preserve query parameters
Load balancers or reverse proxies may rewrite or modify the Location
header and unescape its value, which would result in redirects not
preserving the original set of query parameters. This was especially
evident for autologins where we need to redirect to `/oauth2/login` with
the `redirect` parameter containing the original requested URL so that
the end-user ultimately ends up at the latter URL.

We avoid this issue by base64-encoding the original URL, before passing
it along as the intended redirect for the login route.
To preserve existing behaviour, we use a separate query parameter
for the `/oauth2/login`-endpoint that accepts and handles base64-encoded
values.
2022-09-19 11:51:30 +02:00

718 lines
22 KiB
Go

package handler_test
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/nais/wonderwall/pkg/cookie"
urlpkg "github.com/nais/wonderwall/pkg/handler/url"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/session"
)
func TestHandler_Login(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
resp := localLogin(t, rpClient, idp)
loginURL := resp.Location
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/login")
expectedCallbackURL, err := urlpkg.LoginCallbackURL(req)
assert.NoError(t, err)
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))
assert.Equal(t, "/authorize", loginURL.Path)
assert.Equal(t, idp.OpenIDConfig.Client().ACRValues(), loginURL.Query().Get("acr_values"))
assert.Equal(t, idp.OpenIDConfig.Client().UILocales(), loginURL.Query().Get("ui_locales"))
assert.Equal(t, idp.OpenIDConfig.Client().ClientID(), loginURL.Query().Get("client_id"))
assert.Equal(t, expectedCallbackURL, loginURL.Query().Get("redirect_uri"))
assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method"))
assert.ElementsMatch(t, idp.OpenIDConfig.Client().Scopes(), strings.Split(loginURL.Query().Get("scope"), " "))
assert.NotEmpty(t, loginURL.Query().Get("state"))
assert.NotEmpty(t, loginURL.Query().Get("nonce"))
assert.NotEmpty(t, loginURL.Query().Get("code_challenge"))
resp = get(t, rpClient, loginURL.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
callbackURL := resp.Location
assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state"))
assert.NotEmpty(t, callbackURL.Query().Get("code"))
}
func TestHandler_Callback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
}
func TestHandler_Logout(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := localLogout(t, rpClient, idp)
// Get endsession endpoint after local logout
endsessionURL := resp.Location
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
assert.NoError(t, err)
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout")
expectedLogoutCallbackURL, err := urlpkg.LogoutCallbackURL(req)
assert.NoError(t, err)
endsessionParams := endsessionURL.Query()
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
assert.Equal(t, "/endsession", endsessionURL.Path)
assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"])
assert.NotEmpty(t, endsessionParams["id_token_hint"])
}
func TestHandler_LogoutCallback(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
logout(t, rpClient, idp)
}
func TestHandler_FrontChannelLogout(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
idp.OpenIDConfig.TestProvider.WithFrontChannelLogoutSupport()
defer idp.Close()
rpClient := idp.RelyingPartyClient()
sessionCookie := login(t, rpClient, idp)
// Trigger front-channel logout
sid := func(r *http.Request) string {
ciphertext, err := base64.StdEncoding.DecodeString(sessionCookie.Value)
assert.NoError(t, err)
sessionKey, err := idp.RelyingPartyHandler.GetCrypter().Decrypt(ciphertext)
assert.NoError(t, err)
data, err := idp.RelyingPartyHandler.GetSessions().GetForKey(r, string(sessionKey))
assert.NoError(t, err)
return data.ExternalSessionID
}
frontchannelLogoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout/frontchannel")
assert.NoError(t, err)
req := idp.GetRequest(frontchannelLogoutURL.String())
values := url.Values{}
values.Add("sid", sid(req))
values.Add("iss", idp.OpenIDConfig.Provider().Issuer())
frontchannelLogoutURL.RawQuery = values.Encode()
resp := get(t, rpClient, frontchannelLogoutURL.String())
assert.Equal(t, http.StatusOK, resp.StatusCode)
}
func TestHandler_SessionStateRequired(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
idp.OpenIDConfig.TestProvider.WithCheckSessionIFrameSupport(idp.ProviderServer.URL + "/checksession")
defer idp.Close()
rpClient := idp.RelyingPartyClient()
resp := authorize(t, rpClient, idp)
// Get callback URL after successful auth
params := resp.Location.Query()
sessionState := params.Get("session_state")
assert.NotEmpty(t, sessionState)
}
func TestHandler_SessionInfo(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Minute
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerbose
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
}
func TestHandler_SessionInfo_WithRefresh(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Minute
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerboseWithRefresh
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now(), data.Session.CreatedAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(cfg.Session.MaxLifetime), data.Session.EndsAt, allowedSkew)
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), data.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), data.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(data.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(data.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
// 1 second < next token refresh <= seconds until token expires
assert.LessOrEqual(t, data.Tokens.NextAutoRefreshInSeconds, data.Tokens.ExpireInSeconds)
assert.Greater(t, data.Tokens.NextAutoRefreshInSeconds, int64(1))
assert.True(t, data.Tokens.RefreshCooldown)
// 1 second < refresh cooldown <= minimum refresh interval
assert.LessOrEqual(t, data.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
assert.Greater(t, data.Tokens.RefreshCooldownSeconds, int64(1))
}
func TestHandler_SessionRefresh(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = true
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
// get initial session info
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var data session.MetadataVerboseWithRefresh
err := json.Unmarshal([]byte(resp.Body), &data)
assert.NoError(t, err)
// wait until refresh cooldown has reached zero before refresh
func() {
timeout := time.After(5 * time.Second)
ticker := time.Tick(500 * time.Millisecond)
for {
select {
case <-timeout:
assert.Fail(t, "refresh cooldown timer exceeded timeout")
case <-ticker:
resp := sessionInfo(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var temp session.MetadataVerboseWithRefresh
err = json.Unmarshal([]byte(resp.Body), &temp)
assert.NoError(t, err)
if !temp.Tokens.RefreshCooldown {
return
}
}
}
}()
resp = sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusOK, resp.StatusCode)
var refreshedData session.MetadataVerboseWithRefresh
err = json.Unmarshal([]byte(resp.Body), &refreshedData)
assert.NoError(t, err)
// session create and end times should be unchanged
assert.WithinDuration(t, data.Session.CreatedAt, refreshedData.Session.CreatedAt, 0)
assert.WithinDuration(t, data.Session.EndsAt, refreshedData.Session.EndsAt, 0)
// token expiration and refresh times should be later than before
assert.True(t, refreshedData.Tokens.ExpireAt.After(data.Tokens.ExpireAt))
assert.True(t, refreshedData.Tokens.RefreshedAt.After(data.Tokens.RefreshedAt))
allowedSkew := 5 * time.Second
assert.WithinDuration(t, time.Now().Add(idp.ProviderHandler.TokenDuration), refreshedData.Tokens.ExpireAt, allowedSkew)
assert.WithinDuration(t, time.Now(), refreshedData.Tokens.RefreshedAt, allowedSkew)
sessionEndDuration := time.Duration(refreshedData.Session.EndsInSeconds) * time.Second
// 1 second < time until session ends <= configured max session lifetime
assert.LessOrEqual(t, sessionEndDuration, cfg.Session.MaxLifetime)
assert.Greater(t, sessionEndDuration, time.Second)
tokenExpiryDuration := time.Duration(refreshedData.Tokens.ExpireInSeconds) * time.Second
// 1 second < time until token expires <= max duration for tokens from IDP
assert.LessOrEqual(t, tokenExpiryDuration, idp.ProviderHandler.TokenDuration)
assert.Greater(t, tokenExpiryDuration, time.Second)
// 1 second < next token refresh <= seconds until token expires
assert.LessOrEqual(t, refreshedData.Tokens.NextAutoRefreshInSeconds, refreshedData.Tokens.ExpireInSeconds)
assert.Greater(t, refreshedData.Tokens.NextAutoRefreshInSeconds, int64(1))
assert.True(t, refreshedData.Tokens.RefreshCooldown)
// 1 second < refresh cooldown <= minimum refresh interval
assert.LessOrEqual(t, refreshedData.Tokens.RefreshCooldownSeconds, session.RefreshMinInterval)
assert.Greater(t, refreshedData.Tokens.RefreshCooldownSeconds, int64(1))
}
func TestHandler_SessionRefresh_Disabled(t *testing.T) {
cfg := mock.Config()
cfg.Session.Refresh = false
idp := mock.NewIdentityProvider(cfg)
idp.ProviderHandler.TokenDuration = 5 * time.Second
defer idp.Close()
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := sessionRefresh(t, idp, rpClient)
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
}
func TestHandler_Default(t *testing.T) {
up := newUpstream(t)
defer up.Server.Close()
t.Run("without auto-login", func(t *testing.T) {
cfg := mock.Config()
cfg.UpstreamHost = up.URL.Host
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
up.SetReverseProxyUrl(idp.RelyingPartyServer.URL)
rpClient := idp.RelyingPartyClient()
// initial request without session
resp := get(t, rpClient, idp.RelyingPartyServer.URL)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, "not ok", resp.Body)
// acquire session
login(t, rpClient, idp)
// retry request with session
resp = get(t, rpClient, idp.RelyingPartyServer.URL)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "ok", resp.Body)
})
t.Run("with auto-login", func(t *testing.T) {
cfg := mock.Config()
cfg.AutoLogin = true
cfg.UpstreamHost = up.URL.Host
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
up.SetReverseProxyUrl(idp.RelyingPartyServer.URL)
rpClient := idp.RelyingPartyClient()
// initial request without session
target := idp.RelyingPartyServer.URL + "/"
resp := get(t, rpClient, target)
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// redirect should point to local login endpoint
loginLocation := resp.Location
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect-encoded="+urlpkg.RedirectEncoded("/"), loginLocation.String())
// follow redirect to local login endpoint
resp = get(t, rpClient, loginLocation.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// redirect should point to identity provider
authorizeLocation := resp.Location
authorizeEndpoint := *authorizeLocation
authorizeEndpoint.RawQuery = ""
assert.Equal(t, idp.OpenIDConfig.Provider().AuthorizationEndpoint(), authorizeEndpoint.String())
// follow redirect to identity provider for login
resp = get(t, rpClient, authorizeLocation.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// redirect should point back to relying party
callbackLocation := resp.Location
callbackEndpoint := *callbackLocation
callbackEndpoint.RawQuery = ""
req := idp.GetRequest(callbackLocation.String())
expectedCallbackURL, err := urlpkg.LoginCallbackURL(req)
assert.NoError(t, err)
assert.Equal(t, expectedCallbackURL, callbackEndpoint.String())
// follow redirect back to relying party
resp = get(t, rpClient, callbackLocation.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// finally, follow redirect back to original target, now with a session
targetLocation := resp.Location
assert.Equal(t, target, targetLocation.String())
resp = get(t, rpClient, targetLocation.String())
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "ok", resp.Body)
})
t.Run("with auto-login for non-GET requests", func(t *testing.T) {
for _, method := range []string{
http.MethodConnect,
http.MethodDelete,
http.MethodHead,
http.MethodOptions,
http.MethodPatch,
http.MethodPost,
http.MethodPut,
http.MethodTrace,
} {
t.Run(method, func(t *testing.T) {
cfg := mock.Config()
cfg.AutoLogin = true
cfg.UpstreamHost = up.URL.Host
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
up.SetReverseProxyUrl(idp.RelyingPartyServer.URL)
rpClient := idp.RelyingPartyClient()
req, err := http.NewRequest(method, idp.RelyingPartyServer.URL, nil)
assert.NoError(t, err)
resp, err := rpClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
}
})
t.Run("with auto-login and ignored paths", func(t *testing.T) {
cfg := mock.Config()
cfg.UpstreamHost = up.URL.Host
cfg.AutoLogin = true
cfg.AutoLoginIgnorePaths = []string{
"/exact/match",
"/allowed",
"/wildcard/*",
"/deeper/*/*",
"/any*",
}
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
up.SetReverseProxyUrl(idp.RelyingPartyServer.URL)
rpClient := idp.RelyingPartyClient()
t.Run("matched paths should not trigger login", func(t *testing.T) {
matched := []string{
"/exact/match",
"/allowed",
"/wildcard/",
"/wildcard/very",
"/deeper/1/",
"/deeper/1/2",
"/anything",
"/anywho",
}
for _, path := range matched {
t.Run(path, func(t *testing.T) {
target := idp.RelyingPartyServer.URL + path
resp := get(t, rpClient, target)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, "not ok", resp.Body)
})
}
})
t.Run("non-matched paths should trigger login", func(t *testing.T) {
nonMatched := []string{
"",
"/",
"/exact/match/",
"/exact/match/huh",
"/allowed/",
"/not-allowed",
"/not-allowed/allowed",
"/wildcard",
"/wildcard/yup/nope",
"/deeper",
"/deeper/",
"/deeper/1",
"/deeper/1/2/",
"/deeper/1/2/3",
"/any/",
"/anywho/stvent",
}
for _, path := range nonMatched {
t.Run(path, func(t *testing.T) {
target := idp.RelyingPartyServer.URL + path
resp := get(t, rpClient, target)
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
})
}
})
})
}
func localLogin(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
// First, run /oauth2/login to set cookies
loginURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/login")
assert.NoError(t, err)
resp := get(t, rpClient, loginURL.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
cookies := rpClient.Jar.Cookies(loginURL)
sessionCookie := getCookieFromJar(cookie.Session, cookies)
loginCookie := getCookieFromJar(cookie.Login, cookies)
loginLegacyCookie := getCookieFromJar(cookie.LoginLegacy, cookies)
assert.Nil(t, sessionCookie)
assert.NotNil(t, loginCookie)
assert.NotNil(t, loginLegacyCookie)
return resp
}
func authorize(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
resp := localLogin(t, rpClient, idp)
// Follow redirect to authorize with identity provider
resp = get(t, rpClient, resp.Location.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
return resp
}
func callback(t *testing.T, rpClient *http.Client, authorizeResponse response) *http.Cookie {
// Get callback URL after successful auth
callbackURL := authorizeResponse.Location
// Follow redirect to callback
resp := get(t, rpClient, callbackURL.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
cookies := rpClient.Jar.Cookies(callbackURL)
sessionCookie := getCookieFromJar(cookie.Session, cookies)
loginCookie := getCookieFromJar(cookie.Login, cookies)
loginLegacyCookie := getCookieFromJar(cookie.LoginLegacy, cookies)
assert.NotNil(t, sessionCookie)
assert.Nil(t, loginCookie)
assert.Nil(t, loginLegacyCookie)
return sessionCookie
}
func login(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) *http.Cookie {
resp := authorize(t, rpClient, idp)
return callback(t, rpClient, resp)
}
func localLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
// Request self-initiated logout
logoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout")
assert.NoError(t, err)
resp := get(t, rpClient, logoutURL.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
cookies := rpClient.Jar.Cookies(logoutURL)
sessionCookie := getCookieFromJar(cookie.Session, cookies)
assert.Nil(t, sessionCookie)
return resp
}
func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
// Get endsession endpoint after local logout
resp := localLogout(t, rpClient, idp)
// Follow redirect to endsession endpoint at identity provider
resp = get(t, rpClient, resp.Location.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// Get post-logout redirect URI after successful logout at identity provider
logoutCallbackURI := resp.Location
req := idp.GetRequest(resp.Location.String())
expectedLogoutCallbackURL, err := urlpkg.LogoutCallbackURL(req)
assert.NoError(t, err)
assert.Contains(t, logoutCallbackURI.String(), expectedLogoutCallbackURL)
assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path)
// Follow redirect back to logout callback
resp = get(t, rpClient, logoutCallbackURI.String())
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// Get post-logout redirect URI after redirect back to logout callback
assert.Equal(t, "https://google.com", resp.Location.String())
cookies := rpClient.Jar.Cookies(logoutCallbackURI)
sessionCookie := getCookieFromJar(cookie.Session, cookies)
assert.Nil(t, sessionCookie)
}
func sessionInfo(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
sessionInfoURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session")
assert.NoError(t, err)
return get(t, rpClient, sessionInfoURL.String())
}
func sessionRefresh(t *testing.T, idp *mock.IdentityProvider, rpClient *http.Client) response {
sessionRefreshURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/session/refresh")
assert.NoError(t, err)
return get(t, rpClient, sessionRefreshURL.String())
}
type response struct {
Body string
Location *url.URL
StatusCode int
}
func get(t *testing.T, client *http.Client, url string) response {
resp, err := client.Get(url)
assert.NoError(t, err)
defer resp.Body.Close()
location, err := resp.Location()
if !errors.Is(http.ErrNoLocation, err) {
assert.NoError(t, err)
}
body, err := io.ReadAll(resp.Body)
assert.NoError(t, err)
return response{
Body: string(body),
Location: location,
StatusCode: resp.StatusCode,
}
}
type upstream struct {
Server *httptest.Server
URL *url.URL
reverseProxyURL *url.URL
}
func (u *upstream) SetReverseProxyUrl(raw string) {
parsed, err := url.Parse(raw)
if err != nil {
panic(err)
}
u.reverseProxyURL = parsed
}
func newUpstream(t *testing.T) *upstream {
u := new(upstream)
upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
token := strings.TrimPrefix(authHeader, "Bearer ")
// Host should match the original authority from the ingress used to reach Wonderwall
assert.Equal(t, u.reverseProxyURL.Host, r.Host)
assert.NotEqual(t, u.URL.Host, r.Host)
if len(token) > 0 {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
} else {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte("not ok"))
}
})
server := httptest.NewServer(upstreamHandler)
upstreamURL, err := url.Parse(server.URL)
assert.NoError(t, err)
u.Server = server
u.URL = upstreamURL
return u
}
func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
for _, c := range cookies {
if c.Name == name {
return c
}
}
return nil
}