feat: restrict non-navigational requests to oauth2-routes for all modes

This commit is contained in:
Trong Huu Nguyen
2025-03-06 20:27:51 +01:00
parent 48527e47f8
commit 126db31d25
6 changed files with 59 additions and 26 deletions

View File

@@ -25,19 +25,16 @@ The `auto-login` option will configure Wonderwall to enforce authentication for
If the user is _unauthenticated_ or has an [_inactive_ or _expired_ session](sessions.md), all requests will be short-circuited (i.e. return early and **not** proxied to your application).
The short-circuited response depends on whether the request is a _top-level navigation_ request or not.
A _top-level navigation_ request has the following properties:
1. Is a `GET` request
2. Has the Fetch metadata headers `Sec-Fetch-Dest=document` and `Sec-Fetch-Mode=navigate`
If the user agent does not support the Fetch metadata headers, we look for an `Accept` header that includes `text/html`, which all major browsers send for navigation requests.
A _top-level navigation request_ is a `GET` request that has the [Fetch metadata request headers](https://developer.mozilla.org/en-US/docs/Glossary/Fetch_metadata_request_header) `Sec-Fetch-Dest=document` and `Sec-Fetch-Mode=navigate`.
If the user agent does not support the Fetch metadata headers, we look for an `Accept` header that includes `text/html`, which all major browsers send for navigation requests.
Internet Explorer 8 won't work with this of course, so hopefully you're not in a position that requires supporting this browser.
A _top-level navigation_ request results in a HTTP 302 Found response with the `Location` header pointing to [the `/oauth2/login` endpoint](endpoints.md#oauth2login).
A top-level navigation request results in a HTTP 302 Found response with the `Location` header pointing to [the `/oauth2/login` endpoint](endpoints.md#oauth2login).
The `redirect` parameter in the login URL is set to the value found in the `Referer` header, so that the user is redirected back to their intended location after login.
If the `Referer` header is empty, the `redirect` parameter is set to the matching ingress path for the original request.
Other requests are considered non-navigational requests, and they will result in a HTTP 401 Unauthorized response with the `Location` header set as described above.
Other requests are considered non-navigational requests and result in a HTTP 401 Unauthorized response with the `Location` header set as described above.
For defence in depth, you should still check the `Authorization` header for a token and validate the token even when using auto-login.

View File

@@ -13,9 +13,12 @@ import (
// This should only be used for endpoints that are only supposed to be _navigated to_ from a browser.
// The 401 response prevents redirecting non-navigation requests to the identity provider, which usually results in
// a CORS error for typical Fetch or XHR requests from the browser.
//
// This depends on the presence of the Fetch metadata headers, mostly present in modern browsers.
// For compatibility with older browsers, requests without these headers are still allowed to pass through.
func DisallowNonNavigationalRequests(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if !IsNavigationRequest(r) {
if HasSecFetchMetadata(r) && !IsNavigationRequest(r) {
span := trace.SpanFromContext(r.Context())
span.SetAttributes(attribute.Bool("request.disallowed", true))

View File

@@ -8,21 +8,26 @@ import (
"github.com/nais/wonderwall/pkg/cookie"
)
// IsNavigationRequest checks if the request is a navigation request by using Sec-Fetch headers.
// This is used to separate between redirects for browser navigation and redirects for resource requests (e.g., Fetch or XHR).
// We fall back to checking the Accept header if the browser doesn't support fetch metadata.
func IsNavigationRequest(r *http.Request) bool {
// we assume that navigation requests are always GET requests
if r.Method != http.MethodGet {
return false
}
// check for top-level navigation requests
mode := r.Header.Get("Sec-Fetch-Mode")
dest := r.Header.Get("Sec-Fetch-Dest")
if mode != "" && dest != "" {
return mode == "navigate" && dest == "document"
if mode == "" && dest == "" {
return Accepts(r, "text/html")
}
// fallback if browser doesn't support fetch metadata
return Accepts(r, "text/html")
return mode == "navigate" && dest == "document"
}
func HasSecFetchMetadata(r *http.Request) bool {
return r.Header.Get("Sec-Fetch-Mode") != "" && r.Header.Get("Sec-Fetch-Dest") != ""
}
func Accepts(r *http.Request, accepted ...string) bool {

View File

@@ -362,8 +362,6 @@ func TestPing(t *testing.T) {
func TestNonNavigationalRequests(t *testing.T) {
cfg := mock.Config()
cfg.SSO.Enabled = true
cfg.Session.ForwardAuth = true
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
@@ -373,9 +371,21 @@ func TestNonNavigationalRequests(t *testing.T) {
"/oauth2/logout",
"/oauth2/logout/callback",
} {
rpClient := idp.RelyingPartyClient()
resp := get(t, rpClient, idp.RelyingPartyServer.URL+path)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
t.Run("with fetch metadata", func(t *testing.T) {
rpClient := idp.RelyingPartyClient()
resp := get(t, rpClient, idp.RelyingPartyServer.URL+path,
header{"Sec-Fetch-Mode", "cors"},
header{"Sec-Fetch-Dest", "empty"},
)
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
})
t.Run("without fetch metadata", func(t *testing.T) {
rpClient := idp.RelyingPartyClient()
resp := get(t, rpClient, idp.RelyingPartyServer.URL+path)
assert.GreaterOrEqual(t, resp.StatusCode, http.StatusFound)
assert.LessOrEqual(t, resp.StatusCode, http.StatusPermanentRedirect)
})
}
}

View File

@@ -164,14 +164,25 @@ func TestReverseProxy(t *testing.T) {
up.SetIdentityProvider(idp)
rpClient := idp.RelyingPartyClient()
target := idp.RelyingPartyServer.URL + "/"
resp := get(t, rpClient, target)
assertAutoLoginUnauthorizedResponse(t, idp, resp, "")
t.Run("without fetch metadata", func(t *testing.T) {
target := idp.RelyingPartyServer.URL + "/"
resp := get(t, rpClient, target)
assertAutoLoginUnauthorizedResponse(t, idp, resp, "")
referer := idp.RelyingPartyServer.URL + "/some-path"
target = idp.RelyingPartyServer.URL + "/some-path/resource"
resp = get(t, rpClient, target, header{"Referer", referer})
assertAutoLoginUnauthorizedResponse(t, idp, resp, referer)
referer := idp.RelyingPartyServer.URL + "/some-path"
target = idp.RelyingPartyServer.URL + "/some-path/resource"
resp = get(t, rpClient, target, header{"Referer", referer})
assertAutoLoginUnauthorizedResponse(t, idp, resp, referer)
})
t.Run("with fetch metadata", func(t *testing.T) {
target := idp.RelyingPartyServer.URL + "/"
resp := get(t, rpClient, target,
header{"Sec-Fetch-Mode", "cors"},
header{"Sec-Fetch-Dest", "empty"},
)
assertAutoLoginUnauthorizedResponse(t, idp, resp, "")
})
})
t.Run("with auto-login for navigation request without fetch metadata returns 3xx redirect", func(t *testing.T) {

View File

@@ -2,6 +2,7 @@ package router
import (
"net/http"
"strings"
"github.com/go-chi/chi/v5"
chi_middleware "github.com/go-chi/chi/v5/middleware"
@@ -62,6 +63,9 @@ func New(src Source, cfg *config.Config) chi.Router {
r.Use(otelchi.Middleware(cfg.OpenTelemetry.ServiceName,
otelchi.WithChiRoutes(r),
otelchi.WithRequestMethodInSpanName(true),
otelchi.WithFilter(func(r *http.Request) bool {
return !strings.HasSuffix(r.URL.Path, paths.OAuth2+paths.Ping)
}),
))
r.Use(otel.Middleware)
}
@@ -82,13 +86,16 @@ func New(src Source, cfg *config.Config) chi.Router {
for _, prefix := range prefixes {
r.Route(prefix+paths.OAuth2, func(r chi.Router) {
r.Group(func(r chi.Router) {
if cfg.Session.ForwardAuth {
if cfg.SSO.IsServer() {
r.Use(cors(http.MethodGet, http.MethodHead))
r.Use(httpinternal.DisallowNonNavigationalRequests)
// Cors middleware is designed to be used as a top-level middleware on the chi router.
// Applying with within a r.Group() or using With() will not work without routes matching OPTIONS added.
r.Options(paths.Login, noopHandler)
r.Options(paths.Logout, noopHandler)
} else {
// This branch is necessary because middlewares must be applied before the routes.
r.Use(httpinternal.DisallowNonNavigationalRequests)
}
r.Get(paths.Login, src.Login)
r.Get(paths.Logout, src.Logout)