refactor(handler/test): consistency passthrough, replace unneeded location parsing with stdlib function

This commit is contained in:
Trong Huu Nguyen
2022-07-15 10:23:56 +02:00
parent f6afc3cb6b
commit 9d32d100f0

View File

@@ -25,33 +25,32 @@ func TestHandler_Login(t *testing.T) {
rpClient := idp.RelyingPartyClient()
resp := localLogin(t, rpClient, idp)
defer resp.Body.Close()
location := resp.Header.Get("location")
u, err := url.Parse(location)
loginURL, err := resp.Location()
assert.NoError(t, err)
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", u.Scheme, u.Host))
assert.Equal(t, "/authorize", u.Path)
assert.Equal(t, idp.OpenIDConfig.Client().GetACRValues(), u.Query().Get("acr_values"))
assert.Equal(t, idp.OpenIDConfig.Client().GetUILocales(), u.Query().Get("ui_locales"))
assert.Equal(t, idp.OpenIDConfig.Client().GetClientID(), u.Query().Get("client_id"))
assert.Equal(t, idp.OpenIDConfig.Client().GetCallbackURI(), u.Query().Get("redirect_uri"))
assert.Equal(t, "S256", u.Query().Get("code_challenge_method"))
assert.ElementsMatch(t, idp.OpenIDConfig.Client().GetScopes(), strings.Split(u.Query().Get("scope"), " "))
assert.NotEmpty(t, u.Query().Get("state"))
assert.NotEmpty(t, u.Query().Get("nonce"))
assert.NotEmpty(t, u.Query().Get("code_challenge"))
assert.Equal(t, idp.ProviderServer.URL, fmt.Sprintf("%s://%s", loginURL.Scheme, loginURL.Host))
assert.Equal(t, "/authorize", loginURL.Path)
assert.Equal(t, idp.OpenIDConfig.Client().GetACRValues(), loginURL.Query().Get("acr_values"))
assert.Equal(t, idp.OpenIDConfig.Client().GetUILocales(), loginURL.Query().Get("ui_locales"))
assert.Equal(t, idp.OpenIDConfig.Client().GetClientID(), loginURL.Query().Get("client_id"))
assert.Equal(t, idp.OpenIDConfig.Client().GetCallbackURI(), loginURL.Query().Get("redirect_uri"))
assert.Equal(t, "S256", loginURL.Query().Get("code_challenge_method"))
assert.ElementsMatch(t, idp.OpenIDConfig.Client().GetScopes(), strings.Split(loginURL.Query().Get("scope"), " "))
assert.NotEmpty(t, loginURL.Query().Get("state"))
assert.NotEmpty(t, loginURL.Query().Get("nonce"))
assert.NotEmpty(t, loginURL.Query().Get("code_challenge"))
resp, err = rpClient.Get(u.String())
resp, err = rpClient.Get(loginURL.String())
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
location = resp.Header.Get("location")
callbackURL, err := url.Parse(location)
callbackURL, err := resp.Location()
assert.NoError(t, err)
assert.Equal(t, u.Query().Get("state"), callbackURL.Query().Get("state"))
assert.Equal(t, loginURL.Query().Get("state"), callbackURL.Query().Get("state"))
assert.NotEmpty(t, callbackURL.Query().Get("code"))
}
@@ -73,10 +72,10 @@ func TestHandler_Logout(t *testing.T) {
login(t, rpClient, idp)
resp := localLogout(t, rpClient, idp)
defer resp.Body.Close()
// Get endsession endpoint after local logout
location := resp.Header.Get("location")
endsessionURL, err := url.Parse(location)
endsessionURL, err := resp.Location()
assert.NoError(t, err)
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
@@ -149,10 +148,10 @@ func TestHandler_SessionStateRequired(t *testing.T) {
rpClient := idp.RelyingPartyClient()
resp := authorize(t, rpClient, idp)
defer resp.Body.Close()
// Get callback URL after successful auth
location := resp.Header.Get("location")
callbackURL, err := url.Parse(location)
callbackURL, err := resp.Location()
assert.NoError(t, err)
params := callbackURL.Query()
@@ -176,12 +175,12 @@ func TestHandler_Default(t *testing.T) {
upstream := httptest.NewServer(upstreamHandler)
defer upstream.Close()
upstreamHost, err := url.Parse(upstream.URL)
upstreamURL, err := url.Parse(upstream.URL)
assert.NoError(t, err)
t.Run("without auto-login", func(t *testing.T) {
cfg := mock.Config()
cfg.UpstreamHost = upstreamHost.Host
cfg.UpstreamHost = upstreamURL.Host
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
@@ -196,6 +195,7 @@ func TestHandler_Default(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "not ok", string(body))
// acquire session
login(t, rpClient, idp)
// retry request with session
@@ -212,7 +212,7 @@ func TestHandler_Default(t *testing.T) {
t.Run("with auto-login", func(t *testing.T) {
cfg := mock.Config()
cfg.AutoLogin = true
cfg.UpstreamHost = upstreamHost.Host
cfg.UpstreamHost = upstreamURL.Host
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
@@ -275,7 +275,6 @@ func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider)
resp, err := rpClient.Get(loginURL.String())
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
cookies := rpClient.Jar.Cookies(loginURL)
@@ -292,16 +291,14 @@ func localLogin(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider)
func authorize(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Response {
resp := localLogin(t, rpClient, idp)
defer resp.Body.Close()
// Get authorization URL
location := resp.Header.Get("location")
u, err := url.Parse(location)
authorizeURL, err := resp.Location()
assert.NoError(t, err)
// Follow redirect to authorize with identity provider
resp, err = rpClient.Get(u.String())
resp, err = rpClient.Get(authorizeURL.String())
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
return resp
@@ -309,13 +306,13 @@ func authorize(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *
func callback(t *testing.T, rpClient *http.Client, authorizeResponse *http.Response) *http.Cookie {
// Get callback URL after successful auth
location := authorizeResponse.Header.Get("location")
callbackURL, err := url.Parse(location)
callbackURL, err := authorizeResponse.Location()
assert.NoError(t, err)
// Follow redirect to callback
resp, err := rpClient.Get(callbackURL.String())
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
cookies := rpClient.Jar.Cookies(callbackURL)
@@ -332,6 +329,7 @@ func callback(t *testing.T, rpClient *http.Client, authorizeResponse *http.Respo
func login(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) *http.Cookie {
resp := authorize(t, rpClient, idp)
defer resp.Body.Close()
return callback(t, rpClient, resp)
}
@@ -355,10 +353,10 @@ func localLogout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider)
func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) {
resp := localLogout(t, rpClient, idp)
defer resp.Body.Close()
// Get endsession endpoint after local logout
location := resp.Header.Get("location")
endsessionURL, err := url.Parse(location)
endsessionURL, err := resp.Location()
assert.NoError(t, err)
// Follow redirect to endsession endpoint at identity provider
@@ -368,8 +366,7 @@ func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) {
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// Get post-logout redirect URI after successful logout at identity provider
location = resp.Header.Get("location")
logoutCallbackURI, err := url.Parse(location)
logoutCallbackURI, err := resp.Location()
assert.NoError(t, err)
assert.Contains(t, logoutCallbackURI.String(), idp.OpenIDConfig.Client().GetLogoutCallbackURI())
@@ -378,11 +375,11 @@ func logout(t *testing.T, rpClient *http.Client, idp mock.IdentityProvider) {
// Follow redirect back to logout callback
resp, err = rpClient.Get(logoutCallbackURI.String())
assert.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
// Get post-logout redirect URI after redirect back to logout callback
location = resp.Header.Get("location")
postLogoutRedirectURI, err := url.Parse(location)
postLogoutRedirectURI, err := resp.Location()
assert.NoError(t, err)
assert.Equal(t, idp.OpenIDConfig.Client().GetPostLogoutRedirectURI(), postLogoutRedirectURI.String())