From c363bea556b6b22e52a03547bf810c378c175fdc Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 12 Oct 2023 09:18:51 +0200 Subject: [PATCH] test(reverseproxy): extract common assertions --- pkg/handler/handler_test.go | 29 +++------ pkg/handler/reverseproxy_test.go | 102 +++++++++++++++++-------------- 2 files changed, 65 insertions(+), 66 deletions(-) diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index 92d4ce5..b87f9dd 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -680,7 +680,15 @@ type header struct { } func get(t *testing.T, client *http.Client, url string, headers ...header) response { - req, err := http.NewRequest(http.MethodGet, url, nil) + return request(t, client, http.MethodGet, url, headers...) +} + +func post(t *testing.T, client *http.Client, url string) response { + return request(t, client, http.MethodPost, url) +} + +func request(t *testing.T, client *http.Client, method, url string, headers ...header) response { + req, err := http.NewRequest(method, url, nil) assert.NoError(t, err) for _, h := range headers { @@ -702,25 +710,6 @@ func get(t *testing.T, client *http.Client, url string, headers ...header) respo } } -func post(t *testing.T, client *http.Client, url string) response { - req, err := http.NewRequest(http.MethodPost, url, nil) - assert.NoError(t, err) - - resp, err := client.Do(req) - assert.NoError(t, err) - - location, err := resp.Location() - if !errors.Is(http.ErrNoLocation, err) { - assert.NoError(t, err) - } - - return response{ - Body: body(t, resp), - Location: location, - StatusCode: resp.StatusCode, - } -} - func body(t *testing.T, resp *http.Response) string { defer resp.Body.Close() diff --git a/pkg/handler/reverseproxy_test.go b/pkg/handler/reverseproxy_test.go index 2fec233..f0a3a90 100644 --- a/pkg/handler/reverseproxy_test.go +++ b/pkg/handler/reverseproxy_test.go @@ -2,6 +2,7 @@ package handler_test import ( "net/http" + "net/url" "testing" "github.com/stretchr/testify/assert" @@ -19,6 +20,38 @@ func TestReverseProxy(t *testing.T) { up := newUpstream(t) defer up.Server.Close() + loginURL := func(idp *mock.IdentityProvider, target string) string { + if target == "" { + return idp.RelyingPartyServer.URL + "/oauth2/login" + } + return idp.RelyingPartyServer.URL + "/oauth2/login?redirect=" + url.QueryEscape(target) + } + + // assert that autologin intercepts the request and redirects to the login endpoint + assertAutoLoginRedirectResponse := func(t *testing.T, idp *mock.IdentityProvider, resp response, originalTarget string) { + assert.Equal(t, http.StatusFound, resp.StatusCode) + assert.Equal(t, loginURL(idp, originalTarget), resp.Location.String()) + } + + // assert that auto login intercepts the request and returns a 401 unauthorized + assertAutoLoginUnauthorizedResponse := func(t *testing.T, idp *mock.IdentityProvider, resp response, originalReferer string) { + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, loginURL(idp, originalReferer), resp.Location.String()) + assert.Empty(t, resp.Body) + } + + // assert that the request is proxied to the upstream, which returns a 401 unauthorized + assertUpstreamUnauthorizedResponse := func(t *testing.T, resp response) { + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, "not ok", resp.Body) + } + + // assert that the request is proxied to the upstream, which returns a 200 ok + assertUpstreamOKResponse := func(t *testing.T, resp response) { + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "ok", resp.Body) + } + t.Run("without auto-login", func(t *testing.T) { cfg := mock.Config() cfg.UpstreamHost = up.URL.Host @@ -30,16 +63,14 @@ func TestReverseProxy(t *testing.T) { // initial request without session resp := get(t, rpClient, idp.RelyingPartyServer.URL) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, "not ok", resp.Body) + assertUpstreamUnauthorizedResponse(t, resp) // acquire session login(t, rpClient, idp) // retry request with session resp = get(t, rpClient, idp.RelyingPartyServer.URL) - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "ok", resp.Body) + assertUpstreamOKResponse(t, resp) }) t.Run("with auto-login", func(t *testing.T) { @@ -56,19 +87,14 @@ func TestReverseProxy(t *testing.T) { target := idp.RelyingPartyServer.URL + "/" resp := get(t, rpClient, target, navigateFetchHeaders...) - assert.Equal(t, http.StatusFound, resp.StatusCode) - - // redirect should point to local login endpoint - loginLocation := resp.Location - assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect=%2F", loginLocation.String()) + assertAutoLoginRedirectResponse(t, idp, resp, "/") // follow redirect to local login endpoint - resp = get(t, rpClient, loginLocation.String()) + resp = get(t, rpClient, resp.Location.String()) assert.Equal(t, http.StatusFound, resp.StatusCode) // redirect should point to identity provider authorizeLocation := resp.Location - authorizeEndpoint := *authorizeLocation authorizeEndpoint.RawQuery = "" assert.Equal(t, idp.OpenIDConfig.Provider().AuthorizationEndpoint(), authorizeEndpoint.String()) @@ -79,7 +105,6 @@ func TestReverseProxy(t *testing.T) { // redirect should point back to relying party callbackLocation := resp.Location - callbackEndpoint := *callbackLocation callbackEndpoint.RawQuery = "" @@ -97,11 +122,10 @@ func TestReverseProxy(t *testing.T) { assert.Equal(t, target, targetLocation.String()) resp = get(t, rpClient, targetLocation.String()) - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "ok", resp.Body) + assertUpstreamOKResponse(t, resp) }) - t.Run("with auto-login for non-GET requests", func(t *testing.T) { + t.Run("with auto-login for non-GET requests returns 401 unauthorized", func(t *testing.T) { for _, method := range []string{ http.MethodConnect, http.MethodDelete, @@ -122,23 +146,13 @@ func TestReverseProxy(t *testing.T) { up.SetIdentityProvider(idp) rpClient := idp.RelyingPartyClient() - req, err := http.NewRequest(method, idp.RelyingPartyServer.URL, nil) - assert.NoError(t, err) - - resp, err := rpClient.Do(req) - assert.NoError(t, err) - defer resp.Body.Close() - - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - - location, err := resp.Location() - assert.NoError(t, err) - assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login", location.String()) + resp := request(t, rpClient, method, idp.RelyingPartyServer.URL) + assertAutoLoginUnauthorizedResponse(t, idp, resp, "") }) } }) - t.Run("with auto-login for non-navigation requests", func(t *testing.T) { + t.Run("with auto-login for non-navigation requests returns 401 unauthorized", func(t *testing.T) { cfg := mock.Config() cfg.AutoLogin = true cfg.UpstreamHost = up.URL.Host @@ -149,13 +163,16 @@ func TestReverseProxy(t *testing.T) { rpClient := idp.RelyingPartyClient() target := idp.RelyingPartyServer.URL + "/" - resp := get(t, rpClient, target) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login", resp.Location.String()) + assertAutoLoginUnauthorizedResponse(t, idp, resp, "") + + referer := idp.RelyingPartyServer.URL + "/some-path" + target = idp.RelyingPartyServer.URL + "/some-path/resource" + resp = get(t, rpClient, target, header{"Referer", referer}) + assertAutoLoginUnauthorizedResponse(t, idp, resp, referer) }) - t.Run("with auto-login for navigation request without fetch metadata", func(t *testing.T) { + t.Run("with auto-login for navigation request without fetch metadata returns 3xx redirect", func(t *testing.T) { cfg := mock.Config() cfg.AutoLogin = true cfg.UpstreamHost = up.URL.Host @@ -165,15 +182,14 @@ func TestReverseProxy(t *testing.T) { up.SetIdentityProvider(idp) rpClient := idp.RelyingPartyClient() - target := idp.RelyingPartyServer.URL + "/" + target := idp.RelyingPartyServer.URL + "/some-path" resp := get(t, rpClient, target, header{"Sec-Fetch-Mode", ""}, header{"Sec-Fetch-Dest", ""}, header{"Accept", "text/html"}, ) - assert.Equal(t, http.StatusFound, resp.StatusCode) - assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect=%2F", resp.Location.String()) + assertAutoLoginRedirectResponse(t, idp, resp, "/some-path") }) t.Run("with auto-login and ignored paths", func(t *testing.T) { @@ -308,9 +324,7 @@ func TestReverseProxy(t *testing.T) { t.Run(path, func(t *testing.T) { target := idp.RelyingPartyServer.URL + path resp := get(t, rpClient, target, navigateFetchHeaders...) - - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, "not ok", resp.Body) + assertUpstreamUnauthorizedResponse(t, resp) }) } }) @@ -320,8 +334,7 @@ func TestReverseProxy(t *testing.T) { t.Run(path, func(t *testing.T) { target := idp.RelyingPartyServer.URL + path resp := get(t, rpClient, target, navigateFetchHeaders...) - - assert.Equal(t, http.StatusFound, resp.StatusCode) + assertAutoLoginRedirectResponse(t, idp, resp, path) }) } }) @@ -345,8 +358,7 @@ func TestReverseProxy(t *testing.T) { } resp := get(t, rpClient, idp.RelyingPartyServer.URL, header{"Authorization", "Bearer some-authorization"}) - assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - assert.Equal(t, "not ok", resp.Body) + assertUpstreamUnauthorizedResponse(t, resp) }) t.Run("should be overwritten if session found", func(t *testing.T) { @@ -359,8 +371,7 @@ func TestReverseProxy(t *testing.T) { } resp := get(t, rpClient, idp.RelyingPartyServer.URL, header{"Authorization", "Bearer some-authorization"}) - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "ok", resp.Body) + assertUpstreamOKResponse(t, resp) }) }) @@ -390,7 +401,6 @@ func TestReverseProxy(t *testing.T) { {"X-Forwarded-Host", "wonderwall.example"}, {"X-Forwarded-Proto", "https"}, }...) - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "ok", resp.Body) + assertUpstreamOKResponse(t, resp) }) }