test(reverseproxy): extract common assertions

This commit is contained in:
Trong Huu Nguyen
2023-10-12 09:18:51 +02:00
parent b910d3e65a
commit c363bea556
2 changed files with 65 additions and 66 deletions

View File

@@ -680,7 +680,15 @@ type header struct {
}
func get(t *testing.T, client *http.Client, url string, headers ...header) response {
req, err := http.NewRequest(http.MethodGet, url, nil)
return request(t, client, http.MethodGet, url, headers...)
}
func post(t *testing.T, client *http.Client, url string) response {
return request(t, client, http.MethodPost, url)
}
func request(t *testing.T, client *http.Client, method, url string, headers ...header) response {
req, err := http.NewRequest(method, url, nil)
assert.NoError(t, err)
for _, h := range headers {
@@ -702,25 +710,6 @@ func get(t *testing.T, client *http.Client, url string, headers ...header) respo
}
}
func post(t *testing.T, client *http.Client, url string) response {
req, err := http.NewRequest(http.MethodPost, url, nil)
assert.NoError(t, err)
resp, err := client.Do(req)
assert.NoError(t, err)
location, err := resp.Location()
if !errors.Is(http.ErrNoLocation, err) {
assert.NoError(t, err)
}
return response{
Body: body(t, resp),
Location: location,
StatusCode: resp.StatusCode,
}
}
func body(t *testing.T, resp *http.Response) string {
defer resp.Body.Close()

View File

@@ -2,6 +2,7 @@ package handler_test
import (
"net/http"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
@@ -19,6 +20,38 @@ func TestReverseProxy(t *testing.T) {
up := newUpstream(t)
defer up.Server.Close()
loginURL := func(idp *mock.IdentityProvider, target string) string {
if target == "" {
return idp.RelyingPartyServer.URL + "/oauth2/login"
}
return idp.RelyingPartyServer.URL + "/oauth2/login?redirect=" + url.QueryEscape(target)
}
// assert that autologin intercepts the request and redirects to the login endpoint
assertAutoLoginRedirectResponse := func(t *testing.T, idp *mock.IdentityProvider, resp response, originalTarget string) {
assert.Equal(t, http.StatusFound, resp.StatusCode)
assert.Equal(t, loginURL(idp, originalTarget), resp.Location.String())
}
// assert that auto login intercepts the request and returns a 401 unauthorized
assertAutoLoginUnauthorizedResponse := func(t *testing.T, idp *mock.IdentityProvider, resp response, originalReferer string) {
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, loginURL(idp, originalReferer), resp.Location.String())
assert.Empty(t, resp.Body)
}
// assert that the request is proxied to the upstream, which returns a 401 unauthorized
assertUpstreamUnauthorizedResponse := func(t *testing.T, resp response) {
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, "not ok", resp.Body)
}
// assert that the request is proxied to the upstream, which returns a 200 ok
assertUpstreamOKResponse := func(t *testing.T, resp response) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "ok", resp.Body)
}
t.Run("without auto-login", func(t *testing.T) {
cfg := mock.Config()
cfg.UpstreamHost = up.URL.Host
@@ -30,16 +63,14 @@ func TestReverseProxy(t *testing.T) {
// initial request without session
resp := get(t, rpClient, idp.RelyingPartyServer.URL)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, "not ok", resp.Body)
assertUpstreamUnauthorizedResponse(t, resp)
// acquire session
login(t, rpClient, idp)
// retry request with session
resp = get(t, rpClient, idp.RelyingPartyServer.URL)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "ok", resp.Body)
assertUpstreamOKResponse(t, resp)
})
t.Run("with auto-login", func(t *testing.T) {
@@ -56,19 +87,14 @@ func TestReverseProxy(t *testing.T) {
target := idp.RelyingPartyServer.URL + "/"
resp := get(t, rpClient, target, navigateFetchHeaders...)
assert.Equal(t, http.StatusFound, resp.StatusCode)
// redirect should point to local login endpoint
loginLocation := resp.Location
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect=%2F", loginLocation.String())
assertAutoLoginRedirectResponse(t, idp, resp, "/")
// follow redirect to local login endpoint
resp = get(t, rpClient, loginLocation.String())
resp = get(t, rpClient, resp.Location.String())
assert.Equal(t, http.StatusFound, resp.StatusCode)
// redirect should point to identity provider
authorizeLocation := resp.Location
authorizeEndpoint := *authorizeLocation
authorizeEndpoint.RawQuery = ""
assert.Equal(t, idp.OpenIDConfig.Provider().AuthorizationEndpoint(), authorizeEndpoint.String())
@@ -79,7 +105,6 @@ func TestReverseProxy(t *testing.T) {
// redirect should point back to relying party
callbackLocation := resp.Location
callbackEndpoint := *callbackLocation
callbackEndpoint.RawQuery = ""
@@ -97,11 +122,10 @@ func TestReverseProxy(t *testing.T) {
assert.Equal(t, target, targetLocation.String())
resp = get(t, rpClient, targetLocation.String())
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "ok", resp.Body)
assertUpstreamOKResponse(t, resp)
})
t.Run("with auto-login for non-GET requests", func(t *testing.T) {
t.Run("with auto-login for non-GET requests returns 401 unauthorized", func(t *testing.T) {
for _, method := range []string{
http.MethodConnect,
http.MethodDelete,
@@ -122,23 +146,13 @@ func TestReverseProxy(t *testing.T) {
up.SetIdentityProvider(idp)
rpClient := idp.RelyingPartyClient()
req, err := http.NewRequest(method, idp.RelyingPartyServer.URL, nil)
assert.NoError(t, err)
resp, err := rpClient.Do(req)
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
location, err := resp.Location()
assert.NoError(t, err)
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login", location.String())
resp := request(t, rpClient, method, idp.RelyingPartyServer.URL)
assertAutoLoginUnauthorizedResponse(t, idp, resp, "")
})
}
})
t.Run("with auto-login for non-navigation requests", func(t *testing.T) {
t.Run("with auto-login for non-navigation requests returns 401 unauthorized", func(t *testing.T) {
cfg := mock.Config()
cfg.AutoLogin = true
cfg.UpstreamHost = up.URL.Host
@@ -149,13 +163,16 @@ func TestReverseProxy(t *testing.T) {
rpClient := idp.RelyingPartyClient()
target := idp.RelyingPartyServer.URL + "/"
resp := get(t, rpClient, target)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login", resp.Location.String())
assertAutoLoginUnauthorizedResponse(t, idp, resp, "")
referer := idp.RelyingPartyServer.URL + "/some-path"
target = idp.RelyingPartyServer.URL + "/some-path/resource"
resp = get(t, rpClient, target, header{"Referer", referer})
assertAutoLoginUnauthorizedResponse(t, idp, resp, referer)
})
t.Run("with auto-login for navigation request without fetch metadata", func(t *testing.T) {
t.Run("with auto-login for navigation request without fetch metadata returns 3xx redirect", func(t *testing.T) {
cfg := mock.Config()
cfg.AutoLogin = true
cfg.UpstreamHost = up.URL.Host
@@ -165,15 +182,14 @@ func TestReverseProxy(t *testing.T) {
up.SetIdentityProvider(idp)
rpClient := idp.RelyingPartyClient()
target := idp.RelyingPartyServer.URL + "/"
target := idp.RelyingPartyServer.URL + "/some-path"
resp := get(t, rpClient, target,
header{"Sec-Fetch-Mode", ""},
header{"Sec-Fetch-Dest", ""},
header{"Accept", "text/html"},
)
assert.Equal(t, http.StatusFound, resp.StatusCode)
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect=%2F", resp.Location.String())
assertAutoLoginRedirectResponse(t, idp, resp, "/some-path")
})
t.Run("with auto-login and ignored paths", func(t *testing.T) {
@@ -308,9 +324,7 @@ func TestReverseProxy(t *testing.T) {
t.Run(path, func(t *testing.T) {
target := idp.RelyingPartyServer.URL + path
resp := get(t, rpClient, target, navigateFetchHeaders...)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, "not ok", resp.Body)
assertUpstreamUnauthorizedResponse(t, resp)
})
}
})
@@ -320,8 +334,7 @@ func TestReverseProxy(t *testing.T) {
t.Run(path, func(t *testing.T) {
target := idp.RelyingPartyServer.URL + path
resp := get(t, rpClient, target, navigateFetchHeaders...)
assert.Equal(t, http.StatusFound, resp.StatusCode)
assertAutoLoginRedirectResponse(t, idp, resp, path)
})
}
})
@@ -345,8 +358,7 @@ func TestReverseProxy(t *testing.T) {
}
resp := get(t, rpClient, idp.RelyingPartyServer.URL, header{"Authorization", "Bearer some-authorization"})
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
assert.Equal(t, "not ok", resp.Body)
assertUpstreamUnauthorizedResponse(t, resp)
})
t.Run("should be overwritten if session found", func(t *testing.T) {
@@ -359,8 +371,7 @@ func TestReverseProxy(t *testing.T) {
}
resp := get(t, rpClient, idp.RelyingPartyServer.URL, header{"Authorization", "Bearer some-authorization"})
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "ok", resp.Body)
assertUpstreamOKResponse(t, resp)
})
})
@@ -390,7 +401,6 @@ func TestReverseProxy(t *testing.T) {
{"X-Forwarded-Host", "wonderwall.example"},
{"X-Forwarded-Proto", "https"},
}...)
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "ok", resp.Body)
assertUpstreamOKResponse(t, resp)
})
}