mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-06 00:17:27 +00:00
324 lines
7.9 KiB
Go
324 lines
7.9 KiB
Go
package handler_test
|
|
|
|
import (
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/nais/wonderwall/pkg/mock"
|
|
urlpkg "github.com/nais/wonderwall/pkg/url"
|
|
)
|
|
|
|
func TestReverseProxy(t *testing.T) {
|
|
up := newUpstream(t)
|
|
defer up.Server.Close()
|
|
|
|
t.Run("without auto-login", func(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.UpstreamHost = up.URL.Host
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
up.SetIdentityProvider(idp)
|
|
rpClient := idp.RelyingPartyClient()
|
|
|
|
// initial request without session
|
|
resp := get(t, rpClient, idp.RelyingPartyServer.URL)
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
assert.Equal(t, "not ok", resp.Body)
|
|
|
|
// acquire session
|
|
login(t, rpClient, idp)
|
|
|
|
// retry request with session
|
|
resp = get(t, rpClient, idp.RelyingPartyServer.URL)
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
assert.Equal(t, "ok", resp.Body)
|
|
})
|
|
|
|
t.Run("with auto-login", func(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.AutoLogin = true
|
|
cfg.UpstreamHost = up.URL.Host
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
up.SetIdentityProvider(idp)
|
|
rpClient := idp.RelyingPartyClient()
|
|
|
|
// initial request without session
|
|
target := idp.RelyingPartyServer.URL + "/"
|
|
|
|
resp := get(t, rpClient, target)
|
|
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
|
|
// redirect should point to local login endpoint
|
|
loginLocation := resp.Location
|
|
assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect=%2F", loginLocation.String())
|
|
|
|
// follow redirect to local login endpoint
|
|
resp = get(t, rpClient, loginLocation.String())
|
|
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
|
|
// redirect should point to identity provider
|
|
authorizeLocation := resp.Location
|
|
|
|
authorizeEndpoint := *authorizeLocation
|
|
authorizeEndpoint.RawQuery = ""
|
|
assert.Equal(t, idp.OpenIDConfig.Provider().AuthorizationEndpoint(), authorizeEndpoint.String())
|
|
|
|
// follow redirect to identity provider for login
|
|
resp = get(t, rpClient, authorizeLocation.String())
|
|
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
|
|
// redirect should point back to relying party
|
|
callbackLocation := resp.Location
|
|
|
|
callbackEndpoint := *callbackLocation
|
|
callbackEndpoint.RawQuery = ""
|
|
|
|
req := idp.GetRequest(callbackLocation.String())
|
|
expectedCallbackURL, err := urlpkg.LoginCallback(req)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, expectedCallbackURL, callbackEndpoint.String())
|
|
|
|
// follow redirect back to relying party
|
|
resp = get(t, rpClient, callbackLocation.String())
|
|
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
|
|
// finally, follow redirect back to original target, now with a session
|
|
targetLocation := resp.Location
|
|
assert.Equal(t, target, targetLocation.String())
|
|
|
|
resp = get(t, rpClient, targetLocation.String())
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
assert.Equal(t, "ok", resp.Body)
|
|
})
|
|
|
|
t.Run("with auto-login for non-GET requests", func(t *testing.T) {
|
|
for _, method := range []string{
|
|
http.MethodConnect,
|
|
http.MethodDelete,
|
|
http.MethodHead,
|
|
http.MethodOptions,
|
|
http.MethodPatch,
|
|
http.MethodPost,
|
|
http.MethodPut,
|
|
http.MethodTrace,
|
|
} {
|
|
t.Run(method, func(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.AutoLogin = true
|
|
cfg.UpstreamHost = up.URL.Host
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
up.SetIdentityProvider(idp)
|
|
rpClient := idp.RelyingPartyClient()
|
|
|
|
req, err := http.NewRequest(method, idp.RelyingPartyServer.URL, nil)
|
|
assert.NoError(t, err)
|
|
|
|
resp, err := rpClient.Do(req)
|
|
assert.NoError(t, err)
|
|
defer resp.Body.Close()
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("with auto-login and ignored paths", func(t *testing.T) {
|
|
for pattern, tt := range map[string]struct {
|
|
match []string
|
|
nonMatch []string
|
|
}{
|
|
"/": {
|
|
match: []string{
|
|
"/",
|
|
"",
|
|
},
|
|
nonMatch: []string{
|
|
"/a",
|
|
"/a/b",
|
|
},
|
|
},
|
|
"/exact/match": {
|
|
match: []string{
|
|
"/exact/match",
|
|
"/exact/match/",
|
|
},
|
|
nonMatch: []string{
|
|
"/exact/match/huh",
|
|
},
|
|
},
|
|
"/allowed": {
|
|
match: []string{
|
|
"/allowed",
|
|
"/allowed/",
|
|
},
|
|
nonMatch: []string{
|
|
"/allowe",
|
|
"/allowed/no",
|
|
"/not-allowed",
|
|
"/not-allowed/allowed",
|
|
},
|
|
},
|
|
"/wildcard/*": {
|
|
match: []string{
|
|
"/wildcard/very",
|
|
"/wildcard/very/",
|
|
},
|
|
nonMatch: []string{
|
|
"/wildcard",
|
|
"/wildcard/",
|
|
"/wildcard/yup/nope",
|
|
},
|
|
},
|
|
"/deeper/*/*": {
|
|
match: []string{
|
|
"/deeper/1/2",
|
|
"/deeper/1/2/",
|
|
},
|
|
nonMatch: []string{
|
|
"/deeper",
|
|
"/deeper/",
|
|
"/deeper/1",
|
|
"/deeper/1/",
|
|
"/deeper/1/2/3",
|
|
},
|
|
},
|
|
"/any*": {
|
|
match: []string{
|
|
"/any",
|
|
"/any/",
|
|
"/anything",
|
|
"/anything/",
|
|
"/anywho",
|
|
"/anywho/",
|
|
},
|
|
nonMatch: []string{
|
|
"/any/thing",
|
|
"/any/thing/",
|
|
"/anywho/mst/ve",
|
|
},
|
|
},
|
|
"/trailing/": {
|
|
match: []string{
|
|
"/trailing",
|
|
"/trailing/",
|
|
},
|
|
nonMatch: []string{
|
|
"/trailing/path",
|
|
"/trailing/path/",
|
|
},
|
|
},
|
|
"/nested/**": {
|
|
match: []string{
|
|
"/nested",
|
|
"/nested/",
|
|
"/nested/very",
|
|
"/nested/very/deep",
|
|
"/nested/very/deep/deeper",
|
|
},
|
|
nonMatch: []string{
|
|
"/not/nested",
|
|
"/not/nested/very",
|
|
},
|
|
},
|
|
"/static/**/*.js": {
|
|
match: []string{
|
|
"/static/bundle.js",
|
|
"/static/min/bundle.js",
|
|
"/static/vendor/min/bundle.js",
|
|
},
|
|
nonMatch: []string{
|
|
"/static",
|
|
"/static/",
|
|
"/static/some.css",
|
|
"/static/min",
|
|
"/static/min/",
|
|
"/static/min/some.css",
|
|
"/static/vendor/min/some.css",
|
|
},
|
|
},
|
|
} {
|
|
t.Run(pattern, func(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.UpstreamHost = up.URL.Host
|
|
cfg.AutoLogin = true
|
|
cfg.AutoLoginIgnorePaths = []string{pattern}
|
|
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
up.SetIdentityProvider(idp)
|
|
rpClient := idp.RelyingPartyClient()
|
|
|
|
t.Run("match", func(t *testing.T) {
|
|
for _, path := range tt.match {
|
|
t.Run(path, func(t *testing.T) {
|
|
target := idp.RelyingPartyServer.URL + path
|
|
resp := get(t, rpClient, target)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
assert.Equal(t, "not ok", resp.Body)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("non-match", func(t *testing.T) {
|
|
for _, path := range tt.nonMatch {
|
|
t.Run(path, func(t *testing.T) {
|
|
target := idp.RelyingPartyServer.URL + path
|
|
resp := get(t, rpClient, target)
|
|
|
|
assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode)
|
|
})
|
|
}
|
|
})
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("request with authorization header set", func(t *testing.T) {
|
|
cfg := mock.Config()
|
|
cfg.UpstreamHost = up.URL.Host
|
|
idp := mock.NewIdentityProvider(cfg)
|
|
defer idp.Close()
|
|
|
|
up.SetIdentityProvider(idp)
|
|
rpClient := idp.RelyingPartyClient()
|
|
|
|
t.Run("should be preserved if no session found", func(t *testing.T) {
|
|
up.requestCallback = func(r *http.Request) {
|
|
authorization := r.Header.Get("Authorization")
|
|
assert.Equal(t, "Bearer some-authorization", authorization)
|
|
}
|
|
|
|
resp := getWithHeaders(t, rpClient, idp.RelyingPartyServer.URL, map[string]string{
|
|
"Authorization": "Bearer some-authorization",
|
|
})
|
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
|
assert.Equal(t, "not ok", resp.Body)
|
|
})
|
|
|
|
t.Run("should be overwritten if session found", func(t *testing.T) {
|
|
// acquire session
|
|
login(t, rpClient, idp)
|
|
|
|
up.requestCallback = func(r *http.Request) {
|
|
authorization := r.Header.Get("Authorization")
|
|
assert.NotEqual(t, "Bearer some-authorization", authorization)
|
|
}
|
|
|
|
resp := getWithHeaders(t, rpClient, idp.RelyingPartyServer.URL, map[string]string{
|
|
"Authorization": "Bearer some-authorization",
|
|
})
|
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
|
assert.Equal(t, "ok", resp.Body)
|
|
})
|
|
})
|
|
}
|