diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index b87f9dd..fe700d0 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -692,7 +692,7 @@ func request(t *testing.T, client *http.Client, method, url string, headers ...h assert.NoError(t, err) for _, h := range headers { - req.Header.Set(h.key, h.value) + req.Header.Add(h.key, h.value) } resp, err := client.Do(req) diff --git a/pkg/handler/reverseproxy.go b/pkg/handler/reverseproxy.go index 442f1cb..43a89f0 100644 --- a/pkg/handler/reverseproxy.go +++ b/pkg/handler/reverseproxy.go @@ -152,6 +152,13 @@ func handleAutologin(src ReverseProxySource, w http.ResponseWriter, r *http.Requ location := loginURL(target, "non-navigation request detected; responding with 401 and Location header") w.Header().Set("Location", location) w.WriteHeader(http.StatusUnauthorized) + + if accepts(r, "*/*", "application/json") { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"error": "unauthenticated, please log in"}`)) + } else { + w.Write([]byte("unauthenticated, please log in")) + } } func isNavigationRequest(r *http.Request) bool { @@ -168,10 +175,23 @@ func isNavigationRequest(r *http.Request) bool { } // fallback if browser doesn't support fetch metadata - acceptValues := strings.Split(r.Header.Get("Accept"), ",") - for _, v := range acceptValues { - if strings.ToLower(v) == "text/html" { - return true + return accepts(r, "text/html") +} + +func accepts(r *http.Request, accepted ...string) bool { + // iterate over all Accept headers + for _, header := range r.Header.Values("Accept") { + // iterate over all comma-separated values in a single Accept header + for _, v := range strings.Split(header, ",") { + v = strings.ToLower(v) + v = strings.TrimSpace(v) + v = strings.Split(v, ";")[0] + + for _, accept := range accepted { + if v == accept { + return true + } + } } } diff --git a/pkg/handler/reverseproxy_test.go b/pkg/handler/reverseproxy_test.go index f0a3a90..9fef9aa 100644 --- a/pkg/handler/reverseproxy_test.go +++ b/pkg/handler/reverseproxy_test.go @@ -37,7 +37,7 @@ func TestReverseProxy(t *testing.T) { assertAutoLoginUnauthorizedResponse := func(t *testing.T, idp *mock.IdentityProvider, resp response, originalReferer string) { assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) assert.Equal(t, loginURL(idp, originalReferer), resp.Location.String()) - assert.Empty(t, resp.Body) + assert.Equal(t, "unauthenticated, please log in", resp.Body) } // assert that the request is proxied to the upstream, which returns a 401 unauthorized @@ -147,7 +147,14 @@ func TestReverseProxy(t *testing.T) { rpClient := idp.RelyingPartyClient() resp := request(t, rpClient, method, idp.RelyingPartyServer.URL) - assertAutoLoginUnauthorizedResponse(t, idp, resp, "") + + if method == http.MethodHead { + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + assert.Equal(t, loginURL(idp, ""), resp.Location.String()) + assert.Empty(t, resp.Body) + } else { + assertAutoLoginUnauthorizedResponse(t, idp, resp, "") + } }) } }) @@ -184,12 +191,39 @@ func TestReverseProxy(t *testing.T) { target := idp.RelyingPartyServer.URL + "/some-path" - resp := get(t, rpClient, target, - header{"Sec-Fetch-Mode", ""}, - header{"Sec-Fetch-Dest", ""}, - header{"Accept", "text/html"}, - ) - assertAutoLoginRedirectResponse(t, idp, resp, "/some-path") + for _, tt := range []struct { + name string + headers []header + }{ + {"happy path", []header{ + {"Accept", "text/html"}, + }}, + {"multiple values", []header{ + {"Accept", "application/xhtml+xml, application/xml, text/html"}, + }}, + {"multiple accept headers", []header{ + {"Accept", "application/xhtml+xml, application/xml;q=0.9"}, + {"Accept", "text/plain"}, + {"Accept", "text/html"}, + }}, + {"non-canonical value", []header{ + {"Accept", ", text/HTML "}, + }}, + {"with quality parameter", []header{ + {"Accept", "text/html;q=0.9"}, + }}, + } { + t.Run(tt.name, func(t *testing.T) { + noFetchMetadata := []header{ + {"Sec-Fetch-Mode", ""}, + {"Sec-Fetch-Dest", ""}, + } + tt.headers = append(tt.headers, noFetchMetadata...) + + resp := get(t, rpClient, target, tt.headers...) + assertAutoLoginRedirectResponse(t, idp, resp, "/some-path") + }) + } }) t.Run("with auto-login and ignored paths", func(t *testing.T) {