diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index 9e708c0..cc5e703 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -3,6 +3,7 @@ package handler_test import ( "context" "encoding/base64" + "errors" "fmt" "io/ioutil" "net/http" @@ -25,10 +26,7 @@ func TestHandler_Login(t *testing.T) { rpClient := idp.RelyingPartyClient() resp := localLogin(t, rpClient, idp) - defer resp.Body.Close() - - loginURL, err := resp.Location() - assert.NoError(t, err) + loginURL := resp.Location assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host)) assert.Equal(t, "/authorize", loginURL.Path) @@ -42,14 +40,10 @@ func TestHandler_Login(t *testing.T) { assert.NotEmpty(t, loginURL.Query().Get("nonce")) assert.NotEmpty(t, loginURL.Query().Get("code_challenge")) - resp, err = rpClient.Get(loginURL.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp = get(t, rpClient, loginURL.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) - callbackURL, err := resp.Location() - assert.NoError(t, err) - + callbackURL := resp.Location assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state")) assert.NotEmpty(t, callbackURL.Query().Get("code")) } @@ -72,11 +66,9 @@ func TestHandler_Logout(t *testing.T) { login(t, rpClient, idp) resp := localLogout(t, rpClient, idp) - defer resp.Body.Close() // Get endsession endpoint after local logout - endsessionURL, err := resp.Location() - assert.NoError(t, err) + endsessionURL := resp.Location idpserverURL, err := url.Parse(idp.ProviderServer.URL) assert.NoError(t, err) @@ -133,9 +125,7 @@ func TestHandler_FrontChannelLogout(t *testing.T) { values.Add("iss", idp.OpenIDConfig.Provider().Issuer) frontchannelLogoutURL.RawQuery = values.Encode() - resp, err := rpClient.Get(frontchannelLogoutURL.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp := get(t, rpClient, frontchannelLogoutURL.String()) assert.Equal(t, http.StatusOK, resp.StatusCode) } @@ -148,72 +138,43 @@ func TestHandler_SessionStateRequired(t *testing.T) { rpClient := idp.RelyingPartyClient() resp := authorize(t, rpClient, idp) - defer resp.Body.Close() // Get callback URL after successful auth - callbackURL, err := resp.Location() - assert.NoError(t, err) - - params := callbackURL.Query() + params := resp.Location.Query() sessionState := params.Get("session_state") assert.NotEmpty(t, sessionState) } func TestHandler_Default(t *testing.T) { - upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - authHeader := r.Header.Get("Authorization") - token := strings.TrimPrefix(authHeader, "Bearer ") - - if len(token) > 0 { - w.WriteHeader(http.StatusOK) - _, _ = w.Write([]byte("ok")) - } else { - w.WriteHeader(http.StatusUnauthorized) - _, _ = w.Write([]byte("not ok")) - } - }) - upstream := httptest.NewServer(upstreamHandler) - defer upstream.Close() - - upstreamURL, err := url.Parse(upstream.URL) - assert.NoError(t, err) + up := newUpstream(t) + defer up.Server.Close() t.Run("without auto-login", func(t *testing.T) { cfg := mock.Config() - cfg.UpstreamHost = upstreamURL.Host + cfg.UpstreamHost = up.URL.Host idp := mock.NewIdentityProvider(cfg) defer idp.Close() rpClient := idp.RelyingPartyClient() // initial request without session - resp, err := rpClient.Get(idp.RelyingPartyServer.URL) - assert.NoError(t, err) - defer resp.Body.Close() + resp := get(t, rpClient, idp.RelyingPartyServer.URL) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - - body, err := ioutil.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, "not ok", string(body)) + assert.Equal(t, "not ok", resp.Body) // acquire session login(t, rpClient, idp) // retry request with session - resp, err = rpClient.Get(idp.RelyingPartyServer.URL) - assert.NoError(t, err) - defer resp.Body.Close() + resp = get(t, rpClient, idp.RelyingPartyServer.URL) assert.Equal(t, http.StatusOK, resp.StatusCode) - - body, err = ioutil.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, "ok", string(body)) + assert.Equal(t, "ok", resp.Body) }) t.Run("with auto-login", func(t *testing.T) { cfg := mock.Config() cfg.AutoLogin = true - cfg.UpstreamHost = upstreamURL.Host + cfg.UpstreamHost = up.URL.Host idp := mock.NewIdentityProvider(cfg) defer idp.Close() @@ -222,55 +183,43 @@ func TestHandler_Default(t *testing.T) { // initial request without session target := idp.RelyingPartyServer.URL + "/" - resp, err := rpClient.Get(target) - assert.NoError(t, err) - defer resp.Body.Close() + resp := get(t, rpClient, target) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) // redirect should point to identity provider - authorizeLocation, err := resp.Location() - assert.NoError(t, err) + authorizeLocation := resp.Location + authorizeEndpoint := *authorizeLocation authorizeEndpoint.RawQuery = "" assert.Equal(t, idp.OpenIDConfig.Provider().AuthorizationEndpoint, authorizeEndpoint.String()) // follow redirect to identity provider for login - resp, err = rpClient.Get(authorizeLocation.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp = get(t, rpClient, authorizeLocation.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) // redirect should point back to relying party - callbackLocation, err := resp.Location() - assert.NoError(t, err) + callbackLocation := resp.Location + callbackEndpoint := *callbackLocation callbackEndpoint.RawQuery = "" assert.Equal(t, idp.OpenIDConfig.Client().GetCallbackURI(), callbackEndpoint.String()) // follow redirect back to relying party - resp, err = rpClient.Get(callbackLocation.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp = get(t, rpClient, callbackLocation.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) // finally, follow redirect back to original target, now with a session - targetLocation, err := resp.Location() - assert.NoError(t, err) + targetLocation := resp.Location assert.Equal(t, target, targetLocation.String()) - resp, err = rpClient.Get(targetLocation.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp = get(t, rpClient, targetLocation.String()) assert.Equal(t, http.StatusOK, resp.StatusCode) - - body, err := ioutil.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, "ok", string(body)) + assert.Equal(t, "ok", resp.Body) }) t.Run("with auto-login and skipped paths", func(t *testing.T) { cfg := mock.Config() - cfg.UpstreamHost = upstreamURL.Host + cfg.UpstreamHost = up.URL.Host cfg.AutoLogin = true cfg.AutoLoginSkipPaths = []string{ "^/exact/match$", @@ -300,15 +249,11 @@ func TestHandler_Default(t *testing.T) { } for _, path := range matched { t.Run(path, func(t *testing.T) { - resp, err := rpClient.Get(idp.RelyingPartyServer.URL + path) - assert.NoError(t, err) - defer resp.Body.Close() + target := idp.RelyingPartyServer.URL + path + resp := get(t, rpClient, target) assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) - - body, err := ioutil.ReadAll(resp.Body) - assert.NoError(t, err) - assert.Equal(t, "not ok", string(body)) + assert.Equal(t, "not ok", resp.Body) }) } }) @@ -327,9 +272,8 @@ func TestHandler_Default(t *testing.T) { } for _, path := range nonMatched { t.Run(path, func(t *testing.T) { - resp, err := rpClient.Get(idp.RelyingPartyServer.URL + path) - assert.NoError(t, err) - defer resp.Body.Close() + target := idp.RelyingPartyServer.URL + path + resp := get(t, rpClient, target) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) }) @@ -338,13 +282,12 @@ func TestHandler_Default(t *testing.T) { }) } -func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Response { +func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) response { // First, run /oauth2/login to set cookies loginURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/login") assert.NoError(t, err) - resp, err := rpClient.Get(loginURL.String()) - assert.NoError(t, err) + resp := get(t, rpClient, loginURL.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) cookies := rpClient.Jar.Cookies(loginURL) @@ -359,30 +302,22 @@ func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) return resp } -func authorize(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Response { +func authorize(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) response { resp := localLogin(t, rpClient, idp) - defer resp.Body.Close() - - authorizeURL, err := resp.Location() - assert.NoError(t, err) // Follow redirect to authorize with identity provider - resp, err = rpClient.Get(authorizeURL.String()) - assert.NoError(t, err) + resp = get(t, rpClient, resp.Location.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) return resp } -func callback(t *testing.T, rpClient *http.Client, authorizeResponse *http.Response) *http.Cookie { +func callback(t *testing.T, rpClient *http.Client, authorizeResponse response) *http.Cookie { // Get callback URL after successful auth - callbackURL, err := authorizeResponse.Location() - assert.NoError(t, err) + callbackURL := authorizeResponse.Location // Follow redirect to callback - resp, err := rpClient.Get(callbackURL.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp := get(t, rpClient, callbackURL.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) cookies := rpClient.Jar.Cookies(callbackURL) @@ -399,18 +334,15 @@ func callback(t *testing.T, rpClient *http.Client, authorizeResponse *http.Respo func login(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Cookie { resp := authorize(t, rpClient, idp) - defer resp.Body.Close() return callback(t, rpClient, resp) } -func localLogout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Response { +func localLogout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) response { // Request self-initiated logout logoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout") assert.NoError(t, err) - resp, err := rpClient.Get(logoutURL.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp := get(t, rpClient, logoutURL.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) cookies := rpClient.Jar.Cookies(logoutURL) @@ -422,36 +354,24 @@ func localLogout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) } func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) { - resp := localLogout(t, rpClient, idp) - defer resp.Body.Close() - // Get endsession endpoint after local logout - endsessionURL, err := resp.Location() - assert.NoError(t, err) + resp := localLogout(t, rpClient, idp) // Follow redirect to endsession endpoint at identity provider - resp, err = rpClient.Get(endsessionURL.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp = get(t, rpClient, resp.Location.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) // Get post-logout redirect URI after successful logout at identity provider - logoutCallbackURI, err := resp.Location() - assert.NoError(t, err) + logoutCallbackURI := resp.Location assert.Contains(t, logoutCallbackURI.String(), idp.OpenIDConfig.Client().GetLogoutCallbackURI()) - assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path) // Follow redirect back to logout callback - resp, err = rpClient.Get(logoutCallbackURI.String()) - assert.NoError(t, err) - defer resp.Body.Close() + resp = get(t, rpClient, logoutCallbackURI.String()) assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) // Get post-logout redirect URI after redirect back to logout callback - postLogoutRedirectURI, err := resp.Location() - assert.NoError(t, err) - assert.Equal(t, idp.OpenIDConfig.Client().GetPostLogoutRedirectURI(), postLogoutRedirectURI.String()) + assert.Equal(t, idp.OpenIDConfig.Client().GetPostLogoutRedirectURI(), resp.Location.String()) cookies := rpClient.Jar.Cookies(logoutCallbackURI) sessionCookie := getCookieFromJar(cookie.Session, cookies) @@ -459,6 +379,61 @@ func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) { assert.Nil(t, sessionCookie) } +type response struct { + Body string + Location *url.URL + StatusCode int +} + +func get(t *testing.T, client *http.Client, url string) response { + resp, err := client.Get(url) + assert.NoError(t, err) + defer resp.Body.Close() + + location, err := resp.Location() + if !errors.Is(http.ErrNoLocation, err) { + assert.NoError(t, err) + } + + body, err := ioutil.ReadAll(resp.Body) + assert.NoError(t, err) + + return response{ + Body: string(body), + Location: location, + StatusCode: resp.StatusCode, + } +} + +type upstream struct { + Server *httptest.Server + URL *url.URL +} + +func newUpstream(t *testing.T) upstream { + upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + token := strings.TrimPrefix(authHeader, "Bearer ") + + if len(token) > 0 { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + } else { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte("not ok")) + } + }) + server := httptest.NewServer(upstreamHandler) + + upstreamURL, err := url.Parse(server.URL) + assert.NoError(t, err) + + return upstream{ + Server: server, + URL: upstreamURL, + } +} + func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie { for _, c := range cookies { if c.Name == name {