fix(handler/reverseproxy): do not overwrite host header

This commit is contained in:
Trong Huu Nguyen
2022-08-16 09:24:44 +02:00
parent 758277a267
commit e460a5eab2
2 changed files with 25 additions and 8 deletions

View File

@@ -71,7 +71,6 @@ func newReverseProxy(upstreamHost string) *httputil.ReverseProxy {
// Instruct http.ReverseProxy to not modify X-Forwarded-For header
r.Header["X-Forwarded-For"] = nil
// Request should go to correct host
r.Host = upstreamHost
r.URL.Host = upstreamHost
r.URL.Scheme = "http"

View File

@@ -156,6 +156,7 @@ func TestHandler_Default(t *testing.T) {
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
up.SetReverseProxyUrl(idp.RelyingPartyServer.URL)
rpClient := idp.RelyingPartyClient()
// initial request without session
@@ -179,6 +180,7 @@ func TestHandler_Default(t *testing.T) {
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
up.SetReverseProxyUrl(idp.RelyingPartyServer.URL)
rpClient := idp.RelyingPartyClient()
// initial request without session
@@ -241,6 +243,7 @@ func TestHandler_Default(t *testing.T) {
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
up.SetReverseProxyUrl(idp.RelyingPartyServer.URL)
rpClient := idp.RelyingPartyClient()
t.Run("matched paths should not trigger login", func(t *testing.T) {
@@ -420,15 +423,31 @@ func get(t *testing.T, client *http.Client, url string) response {
}
type upstream struct {
Server *httptest.Server
URL *url.URL
Server *httptest.Server
URL *url.URL
reverseProxyURL *url.URL
}
func newUpstream(t *testing.T) upstream {
func (u *upstream) SetReverseProxyUrl(raw string) {
parsed, err := url.Parse(raw)
if err != nil {
panic(err)
}
u.reverseProxyURL = parsed
}
func newUpstream(t *testing.T) *upstream {
u := new(upstream)
upstreamHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
token := strings.TrimPrefix(authHeader, "Bearer ")
// Host should match the original authority from the ingress used to reach Wonderwall
assert.Equal(t, u.reverseProxyURL.Host, r.Host)
assert.NotEqual(t, u.URL.Host, r.Host)
if len(token) > 0 {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
@@ -442,10 +461,9 @@ func newUpstream(t *testing.T) upstream {
upstreamURL, err := url.Parse(server.URL)
assert.NoError(t, err)
return upstream{
Server: server,
URL: upstreamURL,
}
u.Server = server
u.URL = upstreamURL
return u
}
func getCookieFromJar(name string, cookies []*http.Cookie) *http.Cookie {