mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-07 00:46:56 +00:00
refactor(handler/test): extract upstream and httpclient for readability
This commit is contained in:
@@ -3,6 +3,7 @@ package handler_test
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
@@ -25,10 +26,7 @@ func TestHandler_Login(t *testing.T) {
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
resp := localLogin(t, rpClient, idp)
|
||||
defer resp.Body.Close()
|
||||
|
||||
loginURL, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
loginURL := resp.Location
|
||||
|
||||
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))
|
||||
assert.Equal(t, "/authorize", loginURL.Path)
|
||||
@@ -42,14 +40,10 @@ func TestHandler_Login(t *testing.T) {
|
||||
assert.NotEmpty(t, loginURL.Query().Get("nonce"))
|
||||
assert.NotEmpty(t, loginURL.Query().Get("code_challenge"))
|
||||
|
||||
resp, err = rpClient.Get(loginURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp = get(t, rpClient, loginURL.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
callbackURL, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
|
||||
callbackURL := resp.Location
|
||||
assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state"))
|
||||
assert.NotEmpty(t, callbackURL.Query().Get("code"))
|
||||
}
|
||||
@@ -72,11 +66,9 @@ func TestHandler_Logout(t *testing.T) {
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := localLogout(t, rpClient, idp)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Get endsession endpoint after local logout
|
||||
endsessionURL, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
endsessionURL := resp.Location
|
||||
|
||||
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
|
||||
assert.NoError(t, err)
|
||||
@@ -133,9 +125,7 @@ func TestHandler_FrontChannelLogout(t *testing.T) {
|
||||
values.Add("iss", idp.OpenIDConfig.Provider().Issuer)
|
||||
frontchannelLogoutURL.RawQuery = values.Encode()
|
||||
|
||||
resp, err := rpClient.Get(frontchannelLogoutURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp := get(t, rpClient, frontchannelLogoutURL.String())
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
@@ -148,72 +138,43 @@ func TestHandler_SessionStateRequired(t *testing.T) {
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
resp := authorize(t, rpClient, idp)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Get callback URL after successful auth
|
||||
callbackURL, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
|
||||
params := callbackURL.Query()
|
||||
params := resp.Location.Query()
|
||||
sessionState := params.Get("session_state")
|
||||
assert.NotEmpty(t, sessionState)
|
||||
}
|
||||
|
||||
func TestHandler_Default(t *testing.T) {
|
||||
upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
if len(token) > 0 {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("not ok"))
|
||||
}
|
||||
})
|
||||
upstream := httptest.NewServer(upstreamHandler)
|
||||
defer upstream.Close()
|
||||
|
||||
upstreamURL, err := url.Parse(upstream.URL)
|
||||
assert.NoError(t, err)
|
||||
up := newUpstream(t)
|
||||
defer up.Server.Close()
|
||||
|
||||
t.Run("without auto-login", func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.UpstreamHost = upstreamURL.Host
|
||||
cfg.UpstreamHost = up.URL.Host
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
// initial request without session
|
||||
resp, err := rpClient.Get(idp.RelyingPartyServer.URL)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp := get(t, rpClient, idp.RelyingPartyServer.URL)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "not ok", string(body))
|
||||
assert.Equal(t, "not ok", resp.Body)
|
||||
|
||||
// acquire session
|
||||
login(t, rpClient, idp)
|
||||
|
||||
// retry request with session
|
||||
resp, err = rpClient.Get(idp.RelyingPartyServer.URL)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp = get(t, rpClient, idp.RelyingPartyServer.URL)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err = ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ok", string(body))
|
||||
assert.Equal(t, "ok", resp.Body)
|
||||
})
|
||||
|
||||
t.Run("with auto-login", func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.AutoLogin = true
|
||||
cfg.UpstreamHost = upstreamURL.Host
|
||||
cfg.UpstreamHost = up.URL.Host
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
@@ -222,55 +183,43 @@ func TestHandler_Default(t *testing.T) {
|
||||
// initial request without session
|
||||
target := idp.RelyingPartyServer.URL + "/"
|
||||
|
||||
resp, err := rpClient.Get(target)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp := get(t, rpClient, target)
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
// redirect should point to identity provider
|
||||
authorizeLocation, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
authorizeLocation := resp.Location
|
||||
|
||||
authorizeEndpoint := *authorizeLocation
|
||||
authorizeEndpoint.RawQuery = ""
|
||||
assert.Equal(t, idp.OpenIDConfig.Provider().AuthorizationEndpoint, authorizeEndpoint.String())
|
||||
|
||||
// follow redirect to identity provider for login
|
||||
resp, err = rpClient.Get(authorizeLocation.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp = get(t, rpClient, authorizeLocation.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
// redirect should point back to relying party
|
||||
callbackLocation, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
callbackLocation := resp.Location
|
||||
|
||||
callbackEndpoint := *callbackLocation
|
||||
callbackEndpoint.RawQuery = ""
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().GetCallbackURI(), callbackEndpoint.String())
|
||||
|
||||
// follow redirect back to relying party
|
||||
resp, err = rpClient.Get(callbackLocation.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp = get(t, rpClient, callbackLocation.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
// finally, follow redirect back to original target, now with a session
|
||||
targetLocation, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
targetLocation := resp.Location
|
||||
assert.Equal(t, target, targetLocation.String())
|
||||
|
||||
resp, err = rpClient.Get(targetLocation.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp = get(t, rpClient, targetLocation.String())
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "ok", string(body))
|
||||
assert.Equal(t, "ok", resp.Body)
|
||||
})
|
||||
|
||||
t.Run("with auto-login and skipped paths", func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.UpstreamHost = upstreamURL.Host
|
||||
cfg.UpstreamHost = up.URL.Host
|
||||
cfg.AutoLogin = true
|
||||
cfg.AutoLoginSkipPaths = []string{
|
||||
"^/exact/match$",
|
||||
@@ -300,15 +249,11 @@ func TestHandler_Default(t *testing.T) {
|
||||
}
|
||||
for _, path := range matched {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
resp, err := rpClient.Get(idp.RelyingPartyServer.URL + path)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
target := idp.RelyingPartyServer.URL + path
|
||||
resp := get(t, rpClient, target)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "not ok", string(body))
|
||||
assert.Equal(t, "not ok", resp.Body)
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -327,9 +272,8 @@ func TestHandler_Default(t *testing.T) {
|
||||
}
|
||||
for _, path := range nonMatched {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
resp, err := rpClient.Get(idp.RelyingPartyServer.URL + path)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
target := idp.RelyingPartyServer.URL + path
|
||||
resp := get(t, rpClient, target)
|
||||
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
})
|
||||
@@ -338,13 +282,12 @@ func TestHandler_Default(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Response {
|
||||
func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) response {
|
||||
// First, run /oauth2/login to set cookies
|
||||
loginURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/login")
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := rpClient.Get(loginURL.String())
|
||||
assert.NoError(t, err)
|
||||
resp := get(t, rpClient, loginURL.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
cookies := rpClient.Jar.Cookies(loginURL)
|
||||
@@ -359,30 +302,22 @@ func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider)
|
||||
return resp
|
||||
}
|
||||
|
||||
func authorize(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Response {
|
||||
func authorize(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) response {
|
||||
resp := localLogin(t, rpClient, idp)
|
||||
defer resp.Body.Close()
|
||||
|
||||
authorizeURL, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Follow redirect to authorize with identity provider
|
||||
resp, err = rpClient.Get(authorizeURL.String())
|
||||
assert.NoError(t, err)
|
||||
resp = get(t, rpClient, resp.Location.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func callback(t *testing.T, rpClient *http.Client, authorizeResponse *http.Response) *http.Cookie {
|
||||
func callback(t *testing.T, rpClient *http.Client, authorizeResponse response) *http.Cookie {
|
||||
// Get callback URL after successful auth
|
||||
callbackURL, err := authorizeResponse.Location()
|
||||
assert.NoError(t, err)
|
||||
callbackURL := authorizeResponse.Location
|
||||
|
||||
// Follow redirect to callback
|
||||
resp, err := rpClient.Get(callbackURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp := get(t, rpClient, callbackURL.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
cookies := rpClient.Jar.Cookies(callbackURL)
|
||||
@@ -399,18 +334,15 @@ func callback(t *testing.T, rpClient *http.Client, authorizeResponse *http.Respo
|
||||
|
||||
func login(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Cookie {
|
||||
resp := authorize(t, rpClient, idp)
|
||||
defer resp.Body.Close()
|
||||
return callback(t, rpClient, resp)
|
||||
}
|
||||
|
||||
func localLogout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Response {
|
||||
func localLogout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) response {
|
||||
// Request self-initiated logout
|
||||
logoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout")
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := rpClient.Get(logoutURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp := get(t, rpClient, logoutURL.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
cookies := rpClient.Jar.Cookies(logoutURL)
|
||||
@@ -422,36 +354,24 @@ func localLogout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider)
|
||||
}
|
||||
|
||||
func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) {
|
||||
resp := localLogout(t, rpClient, idp)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Get endsession endpoint after local logout
|
||||
endsessionURL, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
resp := localLogout(t, rpClient, idp)
|
||||
|
||||
// Follow redirect to endsession endpoint at identity provider
|
||||
resp, err = rpClient.Get(endsessionURL.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp = get(t, rpClient, resp.Location.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
// Get post-logout redirect URI after successful logout at identity provider
|
||||
logoutCallbackURI, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
logoutCallbackURI := resp.Location
|
||||
assert.Contains(t, logoutCallbackURI.String(), idp.OpenIDConfig.Client().GetLogoutCallbackURI())
|
||||
|
||||
assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path)
|
||||
|
||||
// Follow redirect back to logout callback
|
||||
resp, err = rpClient.Get(logoutCallbackURI.String())
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
resp = get(t, rpClient, logoutCallbackURI.String())
|
||||
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
||||
|
||||
// Get post-logout redirect URI after redirect back to logout callback
|
||||
postLogoutRedirectURI, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().GetPostLogoutRedirectURI(), postLogoutRedirectURI.String())
|
||||
assert.Equal(t, idp.OpenIDConfig.Client().GetPostLogoutRedirectURI(), resp.Location.String())
|
||||
|
||||
cookies := rpClient.Jar.Cookies(logoutCallbackURI)
|
||||
sessionCookie := getCookieFromJar(cookie.Session, cookies)
|
||||
@@ -459,6 +379,61 @@ func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) {
|
||||
assert.Nil(t, sessionCookie)
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Body string
|
||||
Location *url.URL
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
func get(t *testing.T, client *http.Client, url string) response {
|
||||
resp, err := client.Get(url)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
location, err := resp.Location()
|
||||
if !errors.Is(http.ErrNoLocation, err) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return response{
|
||||
Body: string(body),
|
||||
Location: location,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
type upstream struct {
|
||||
Server *httptest.Server
|
||||
URL *url.URL
|
||||
}
|
||||
|
||||
func newUpstream(t *testing.T) upstream {
|
||||
upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
|
||||
if len(token) > 0 {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
} else {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("not ok"))
|
||||
}
|
||||
})
|
||||
server := httptest.NewServer(upstreamHandler)
|
||||
|
||||
upstreamURL, err := url.Parse(server.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return upstream{
|
||||
Server: server,
|
||||
URL: upstreamURL,
|
||||
}
|
||||
}
|
||||
|
||||
func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {
|
||||
for _, c := range cookies {
|
||||
if c.Name == name {
|
||||
|
||||
Reference in New Issue
Block a user