From 9d32d100f065fa5d82cf2ecfd090e12343f43dbc Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Fri, 15 Jul 2022 10:23:56 +0200 Subject: [PATCH] refactor(handler/test): consistency passthrough, replace unneeded location parsing with stdlib function --- pkg/handler/handler_test.go | 73 ++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index b2f3644..e398816 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -25,33 +25,32 @@ func TestHandler_Login(t *testing.T) { rpClient := idp.RelyingPartyClient() resp := localLogin(t, rpClient, idp) + defer resp.Body.Close() - location := resp.Header.Get("location") - u, err := url.Parse(location) + loginURL, err := resp.Location() assert.NoError(t, err) - assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", u.Scheme, u.Host)) - assert.Equal(t, "/authorize", u.Path) - assert.Equal(t, idp.OpenIDConfig.Client().GetACRValues(), u.Query().Get("acr_values")) - assert.Equal(t, idp.OpenIDConfig.Client().GetUILocales(), u.Query().Get("ui_locales")) - assert.Equal(t, idp.OpenIDConfig.Client().GetClientID(), u.Query().Get("client_id")) - assert.Equal(t, idp.OpenIDConfig.Client().GetCallbackURI(), u.Query().Get("redirect_uri")) - assert.Equal(t, "S256", u.Query().Get("code_challenge_method")) - assert.ElementsMatch(t, idp.OpenIDConfig.Client().GetScopes(), strings.Split(u.Query().Get("scope"), " ")) - assert.NotEmpty(t, u.Query().Get("state")) - assert.NotEmpty(t, u.Query().Get("nonce")) - assert.NotEmpty(t, u.Query().Get("code_challenge")) + assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host)) + assert.Equal(t, "/authorize", loginURL.Path) + assert.Equal(t, idp.OpenIDConfig.Client().GetACRValues(), loginURL.Query().Get("acr_values")) + assert.Equal(t, idp.OpenIDConfig.Client().GetUILocales(), loginURL.Query().Get("ui_locales")) + assert.Equal(t, idp.OpenIDConfig.Client().GetClientID(), loginURL.Query().Get("client_id")) + assert.Equal(t, idp.OpenIDConfig.Client().GetCallbackURI(), loginURL.Query().Get("redirect_uri")) + assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method")) + assert.ElementsMatch(t, idp.OpenIDConfig.Client().GetScopes(), strings.Split(loginURL.Query().Get("scope"), " ")) + assert.NotEmpty(t, loginURL.Query().Get("state")) + assert.NotEmpty(t, loginURL.Query().Get("nonce")) + assert.NotEmpty(t, loginURL.Query().Get("code_challenge")) - resp, err = rpClient.Get(u.String()) + resp, err = rpClient.Get(loginURL.String()) assert.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - location = resp.Header.Get("location") - callbackURL, err := url.Parse(location) + callbackURL, err := resp.Location() assert.NoError(t, err) - assert.Equal(t, u.Query().Get("state"), callbackURL.Query().Get("state")) + assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state")) assert.NotEmpty(t, callbackURL.Query().Get("code")) } @@ -73,10 +72,10 @@ func TestHandler_Logout(t *testing.T) { login(t, rpClient, idp) resp := localLogout(t, rpClient, idp) + defer resp.Body.Close() // Get endsession endpoint after local logout - location := resp.Header.Get("location") - endsessionURL, err := url.Parse(location) + endsessionURL, err := resp.Location() assert.NoError(t, err) idpserverURL, err := url.Parse(idp.ProviderServer.URL) @@ -149,10 +148,10 @@ func TestHandler_SessionStateRequired(t *testing.T) { rpClient := idp.RelyingPartyClient() resp := authorize(t, rpClient, idp) + defer resp.Body.Close() // Get callback URL after successful auth - location := resp.Header.Get("location") - callbackURL, err := url.Parse(location) + callbackURL, err := resp.Location() assert.NoError(t, err) params := callbackURL.Query() @@ -176,12 +175,12 @@ func TestHandler_Default(t *testing.T) { upstream := httptest.NewServer(upstreamHandler) defer upstream.Close() - upstreamHost, err := url.Parse(upstream.URL) + upstreamURL, err := url.Parse(upstream.URL) assert.NoError(t, err) t.Run("without auto-login", func(t *testing.T) { cfg := mock.Config() - cfg.UpstreamHost = upstreamHost.Host + cfg.UpstreamHost = upstreamURL.Host idp := mock.NewIdentityProvider(cfg) defer idp.Close() @@ -196,6 +195,7 @@ func TestHandler_Default(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "not ok", string(body)) + // acquire session login(t, rpClient, idp) // retry request with session @@ -212,7 +212,7 @@ func TestHandler_Default(t *testing.T) { t.Run("with auto-login", func(t *testing.T) { cfg := mock.Config() cfg.AutoLogin = true - cfg.UpstreamHost = upstreamHost.Host + cfg.UpstreamHost = upstreamURL.Host idp := mock.NewIdentityProvider(cfg) defer idp.Close() @@ -275,7 +275,6 @@ func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) resp, err := rpClient.Get(loginURL.String()) assert.NoError(t, err) - defer resp.Body.Close() assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) cookies := rpClient.Jar.Cookies(loginURL) @@ -292,16 +291,14 @@ func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) func authorize(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Response { resp := localLogin(t, rpClient, idp) + defer resp.Body.Close() - // Get authorization URL - location := resp.Header.Get("location") - u, err := url.Parse(location) + authorizeURL, err := resp.Location() assert.NoError(t, err) // Follow redirect to authorize with identity provider - resp, err = rpClient.Get(u.String()) + resp, err = rpClient.Get(authorizeURL.String()) assert.NoError(t, err) - defer resp.Body.Close() assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) return resp @@ -309,13 +306,13 @@ func authorize(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) * func callback(t *testing.T, rpClient *http.Client, authorizeResponse *http.Response) *http.Cookie { // Get callback URL after successful auth - location := authorizeResponse.Header.Get("location") - callbackURL, err := url.Parse(location) + callbackURL, err := authorizeResponse.Location() assert.NoError(t, err) // Follow redirect to callback resp, err := rpClient.Get(callbackURL.String()) assert.NoError(t, err) + defer resp.Body.Close() assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) cookies := rpClient.Jar.Cookies(callbackURL) @@ -332,6 +329,7 @@ func callback(t *testing.T, rpClient *http.Client, authorizeResponse *http.Respo func login(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Cookie { resp := authorize(t, rpClient, idp) + defer resp.Body.Close() return callback(t, rpClient, resp) } @@ -355,10 +353,10 @@ func localLogout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) { resp := localLogout(t, rpClient, idp) + defer resp.Body.Close() // Get endsession endpoint after local logout - location := resp.Header.Get("location") - endsessionURL, err := url.Parse(location) + endsessionURL, err := resp.Location() assert.NoError(t, err) // Follow redirect to endsession endpoint at identity provider @@ -368,8 +366,7 @@ func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) { assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) // Get post-logout redirect URI after successful logout at identity provider - location = resp.Header.Get("location") - logoutCallbackURI, err := url.Parse(location) + logoutCallbackURI, err := resp.Location() assert.NoError(t, err) assert.Contains(t, logoutCallbackURI.String(), idp.OpenIDConfig.Client().GetLogoutCallbackURI()) @@ -378,11 +375,11 @@ func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) { // Follow redirect back to logout callback resp, err = rpClient.Get(logoutCallbackURI.String()) assert.NoError(t, err) + defer resp.Body.Close() assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) // Get post-logout redirect URI after redirect back to logout callback - location = resp.Header.Get("location") - postLogoutRedirectURI, err := url.Parse(location) + postLogoutRedirectURI, err := resp.Location() assert.NoError(t, err) assert.Equal(t, idp.OpenIDConfig.Client().GetPostLogoutRedirectURI(), postLogoutRedirectURI.String())