diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index f472f50..aeb5f09 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -71,7 +71,6 @@ func newReverseProxy(upstreamHost string) *httputil.ReverseProxy { // Instruct http.ReverseProxy to not modify X-Forwarded-For header r.Header["X-Forwarded-For"] = nil // Request should go to correct host - r.Host = upstreamHost r.URL.Host = upstreamHost r.URL.Scheme = "http" diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index 70803df..b068017 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -156,6 +156,7 @@ func TestHandler_Default(t *testing.T) { idp := mock.NewIdentityProvider(cfg) defer idp.Close() + up.SetReverseProxyUrl(idp.RelyingPartyServer.URL) rpClient := idp.RelyingPartyClient() // initial request without session @@ -179,6 +180,7 @@ func TestHandler_Default(t *testing.T) { idp := mock.NewIdentityProvider(cfg) defer idp.Close() + up.SetReverseProxyUrl(idp.RelyingPartyServer.URL) rpClient := idp.RelyingPartyClient() // initial request without session @@ -241,6 +243,7 @@ func TestHandler_Default(t *testing.T) { idp := mock.NewIdentityProvider(cfg) defer idp.Close() + up.SetReverseProxyUrl(idp.RelyingPartyServer.URL) rpClient := idp.RelyingPartyClient() t.Run("matched paths should not trigger login", func(t *testing.T) { @@ -420,15 +423,31 @@ func get(t *testing.T, client *http.Client, url string) response { } type upstream struct { - Server *httptest.Server - URL *url.URL + Server *httptest.Server + URL *url.URL + reverseProxyURL *url.URL } -func newUpstream(t *testing.T) upstream { +func (u *upstream) SetReverseProxyUrl(raw string) { + parsed, err := url.Parse(raw) + if err != nil { + panic(err) + } + + u.reverseProxyURL = parsed +} + +func newUpstream(t *testing.T) *upstream { + u := new(upstream) + upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { authHeader := r.Header.Get("Authorization") token := strings.TrimPrefix(authHeader, "Bearer ") + // Host should match the original authority from the ingress used to reach Wonderwall + assert.Equal(t, u.reverseProxyURL.Host, r.Host) + assert.NotEqual(t, u.URL.Host, r.Host) + if len(token) > 0 { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("ok")) @@ -442,10 +461,9 @@ func newUpstream(t *testing.T) upstream { upstreamURL, err := url.Parse(server.URL) assert.NoError(t, err) - return upstream{ - Server: server, - URL: upstreamURL, - } + u.Server = server + u.URL = upstreamURL + return u } func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {