mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-11 02:47:05 +00:00
test(reverseproxy): extract common assertions
This commit is contained in:
@@ -680,7 +680,15 @@ type header struct {
|
||||
}
|
||||
|
||||
func get(t *testing.T, client *http.Client, url string, headers ...header) response {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
return request(t, client, http.MethodGet, url, headers...)
|
||||
}
|
||||
|
||||
func post(t *testing.T, client *http.Client, url string) response {
|
||||
return request(t, client, http.MethodPost, url)
|
||||
}
|
||||
|
||||
func request(t *testing.T, client *http.Client, method, url string, headers ...header) response {
|
||||
req, err := http.NewRequest(method, url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, h := range headers {
|
||||
@@ -702,25 +710,6 @@ func get(t *testing.T, client *http.Client, url string, headers ...header) respo
|
||||
}
|
||||
}
|
||||
|
||||
func post(t *testing.T, client *http.Client, url string) response {
|
||||
req, err := http.NewRequest(http.MethodPost, url, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
location, err := resp.Location()
|
||||
if !errors.Is(http.ErrNoLocation, err) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
return response{
|
||||
Body: body(t, resp),
|
||||
Location: location,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
}
|
||||
|
||||
func body(t *testing.T, resp *http.Response) string {
|
||||
defer resp.Body.Close()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package handler_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -19,6 +20,38 @@ func TestReverseProxy(t *testing.T) {
|
||||
up := newUpstream(t)
|
||||
defer up.Server.Close()
|
||||
|
||||
loginURL := func(idp *mock.IdentityProvider, target string) string {
|
||||
if target == "" {
|
||||
return idp.RelyingPartyServer.URL + "/oauth2/login"
|
||||
}
|
||||
return idp.RelyingPartyServer.URL + "/oauth2/login?redirect=" + url.QueryEscape(target)
|
||||
}
|
||||
|
||||
// assert that autologin intercepts the request and redirects to the login endpoint
|
||||
assertAutoLoginRedirectResponse := func(t *testing.T, idp *mock.IdentityProvider, resp response, originalTarget string) {
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
assert.Equal(t, loginURL(idp, originalTarget), resp.Location.String())
|
||||
}
|
||||
|
||||
// assert that auto login intercepts the request and returns a 401 unauthorized
|
||||
assertAutoLoginUnauthorizedResponse := func(t *testing.T, idp *mock.IdentityProvider, resp response, originalReferer string) {
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, loginURL(idp, originalReferer), resp.Location.String())
|
||||
assert.Empty(t, resp.Body)
|
||||
}
|
||||
|
||||
// assert that the request is proxied to the upstream, which returns a 401 unauthorized
|
||||
assertUpstreamUnauthorizedResponse := func(t *testing.T, resp response) {
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, "not ok", resp.Body)
|
||||
}
|
||||
|
||||
// assert that the request is proxied to the upstream, which returns a 200 ok
|
||||
assertUpstreamOKResponse := func(t *testing.T, resp response) {
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "ok", resp.Body)
|
||||
}
|
||||
|
||||
t.Run("without auto-login", func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.UpstreamHost = up.URL.Host
|
||||
@@ -30,16 +63,14 @@ func TestReverseProxy(t *testing.T) {
|
||||
|
||||
// initial request without session
|
||||
resp := get(t, rpClient, idp.RelyingPartyServer.URL)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, "not ok", resp.Body)
|
||||
assertUpstreamUnauthorizedResponse(t, resp)
|
||||
|
||||
// acquire session
|
||||
login(t, rpClient, idp)
|
||||
|
||||
// retry request with session
|
||||
resp = get(t, rpClient, idp.RelyingPartyServer.URL)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "ok", resp.Body)
|
||||
assertUpstreamOKResponse(t, resp)
|
||||
})
|
||||
|
||||
t.Run("with auto-login", func(t *testing.T) {
|
||||
@@ -56,19 +87,14 @@ func TestReverseProxy(t *testing.T) {
|
||||
target := idp.RelyingPartyServer.URL + "/"
|
||||
|
||||
resp := get(t, rpClient, target, navigateFetchHeaders...)
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
|
||||
// redirect should point to local login endpoint
|
||||
loginLocation := resp.Location
|
||||
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect=%2F", loginLocation.String())
|
||||
assertAutoLoginRedirectResponse(t, idp, resp, "/")
|
||||
|
||||
// follow redirect to local login endpoint
|
||||
resp = get(t, rpClient, loginLocation.String())
|
||||
resp = get(t, rpClient, resp.Location.String())
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
|
||||
// redirect should point to identity provider
|
||||
authorizeLocation := resp.Location
|
||||
|
||||
authorizeEndpoint := *authorizeLocation
|
||||
authorizeEndpoint.RawQuery = ""
|
||||
assert.Equal(t, idp.OpenIDConfig.Provider().AuthorizationEndpoint(), authorizeEndpoint.String())
|
||||
@@ -79,7 +105,6 @@ func TestReverseProxy(t *testing.T) {
|
||||
|
||||
// redirect should point back to relying party
|
||||
callbackLocation := resp.Location
|
||||
|
||||
callbackEndpoint := *callbackLocation
|
||||
callbackEndpoint.RawQuery = ""
|
||||
|
||||
@@ -97,11 +122,10 @@ func TestReverseProxy(t *testing.T) {
|
||||
assert.Equal(t, target, targetLocation.String())
|
||||
|
||||
resp = get(t, rpClient, targetLocation.String())
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "ok", resp.Body)
|
||||
assertUpstreamOKResponse(t, resp)
|
||||
})
|
||||
|
||||
t.Run("with auto-login for non-GET requests", func(t *testing.T) {
|
||||
t.Run("with auto-login for non-GET requests returns 401 unauthorized", func(t *testing.T) {
|
||||
for _, method := range []string{
|
||||
http.MethodConnect,
|
||||
http.MethodDelete,
|
||||
@@ -122,23 +146,13 @@ func TestReverseProxy(t *testing.T) {
|
||||
up.SetIdentityProvider(idp)
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
req, err := http.NewRequest(method, idp.RelyingPartyServer.URL, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
resp, err := rpClient.Do(req)
|
||||
assert.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
|
||||
location, err := resp.Location()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login", location.String())
|
||||
resp := request(t, rpClient, method, idp.RelyingPartyServer.URL)
|
||||
assertAutoLoginUnauthorizedResponse(t, idp, resp, "")
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with auto-login for non-navigation requests", func(t *testing.T) {
|
||||
t.Run("with auto-login for non-navigation requests returns 401 unauthorized", func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.AutoLogin = true
|
||||
cfg.UpstreamHost = up.URL.Host
|
||||
@@ -149,13 +163,16 @@ func TestReverseProxy(t *testing.T) {
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
target := idp.RelyingPartyServer.URL + "/"
|
||||
|
||||
resp := get(t, rpClient, target)
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login", resp.Location.String())
|
||||
assertAutoLoginUnauthorizedResponse(t, idp, resp, "")
|
||||
|
||||
referer := idp.RelyingPartyServer.URL + "/some-path"
|
||||
target = idp.RelyingPartyServer.URL + "/some-path/resource"
|
||||
resp = get(t, rpClient, target, header{"Referer", referer})
|
||||
assertAutoLoginUnauthorizedResponse(t, idp, resp, referer)
|
||||
})
|
||||
|
||||
t.Run("with auto-login for navigation request without fetch metadata", func(t *testing.T) {
|
||||
t.Run("with auto-login for navigation request without fetch metadata returns 3xx redirect", func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.AutoLogin = true
|
||||
cfg.UpstreamHost = up.URL.Host
|
||||
@@ -165,15 +182,14 @@ func TestReverseProxy(t *testing.T) {
|
||||
up.SetIdentityProvider(idp)
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
|
||||
target := idp.RelyingPartyServer.URL + "/"
|
||||
target := idp.RelyingPartyServer.URL + "/some-path"
|
||||
|
||||
resp := get(t, rpClient, target,
|
||||
header{"Sec-Fetch-Mode", ""},
|
||||
header{"Sec-Fetch-Dest", ""},
|
||||
header{"Accept", "text/html"},
|
||||
)
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect=%2F", resp.Location.String())
|
||||
assertAutoLoginRedirectResponse(t, idp, resp, "/some-path")
|
||||
})
|
||||
|
||||
t.Run("with auto-login and ignored paths", func(t *testing.T) {
|
||||
@@ -308,9 +324,7 @@ func TestReverseProxy(t *testing.T) {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
target := idp.RelyingPartyServer.URL + path
|
||||
resp := get(t, rpClient, target, navigateFetchHeaders...)
|
||||
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, "not ok", resp.Body)
|
||||
assertUpstreamUnauthorizedResponse(t, resp)
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -320,8 +334,7 @@ func TestReverseProxy(t *testing.T) {
|
||||
t.Run(path, func(t *testing.T) {
|
||||
target := idp.RelyingPartyServer.URL + path
|
||||
resp := get(t, rpClient, target, navigateFetchHeaders...)
|
||||
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
assertAutoLoginRedirectResponse(t, idp, resp, path)
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -345,8 +358,7 @@ func TestReverseProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
resp := get(t, rpClient, idp.RelyingPartyServer.URL, header{"Authorization", "Bearer some-authorization"})
|
||||
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||
assert.Equal(t, "not ok", resp.Body)
|
||||
assertUpstreamUnauthorizedResponse(t, resp)
|
||||
})
|
||||
|
||||
t.Run("should be overwritten if session found", func(t *testing.T) {
|
||||
@@ -359,8 +371,7 @@ func TestReverseProxy(t *testing.T) {
|
||||
}
|
||||
|
||||
resp := get(t, rpClient, idp.RelyingPartyServer.URL, header{"Authorization", "Bearer some-authorization"})
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "ok", resp.Body)
|
||||
assertUpstreamOKResponse(t, resp)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -390,7 +401,6 @@ func TestReverseProxy(t *testing.T) {
|
||||
{"X-Forwarded-Host", "wonderwall.example"},
|
||||
{"X-Forwarded-Proto", "https"},
|
||||
}...)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
assert.Equal(t, "ok", resp.Body)
|
||||
assertUpstreamOKResponse(t, resp)
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user