diff --git a/pkg/errorhandler/errorhandler.go b/pkg/errorhandler/errorhandler.go index 840d5e8..a6542fc 100644 --- a/pkg/errorhandler/errorhandler.go +++ b/pkg/errorhandler/errorhandler.go @@ -2,7 +2,7 @@ package errorhandler import ( "github.com/go-chi/chi/v5/middleware" - "github.com/nais/wonderwall/pkg/url" + "github.com/nais/wonderwall/pkg/request" "html/template" "net/http" @@ -21,7 +21,7 @@ func respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause t, _ := template.ParseFiles("templates/error.html") errorPage := ErrorPage{ CorrelationID: middleware.GetReqID(r.Context()), - CanonicalRedirectURL: url.CanonicalRedirectURL(r), + CanonicalRedirectURL: request.CanonicalRedirectURL(r), } t.Execute(w, errorPage) } diff --git a/pkg/request/parameters.go b/pkg/request/parameters.go new file mode 100644 index 0000000..4478566 --- /dev/null +++ b/pkg/request/parameters.go @@ -0,0 +1,8 @@ +package request + +const ( + LocaleURLParameter = "locale" + PostLogoutRedirectURIParameter = "post_logout_redirect_uri" + RedirectURLParameter = "redirect" + SecurityLevelURLParameter = "level" +) diff --git a/pkg/router/request.go b/pkg/request/request.go similarity index 59% rename from pkg/router/request.go rename to pkg/request/request.go index f727052..4e272b0 100644 --- a/pkg/router/request.go +++ b/pkg/request/request.go @@ -1,16 +1,37 @@ -package router +package request import ( "errors" "fmt" "github.com/nais/wonderwall/pkg/config" "net/http" + "net/url" ) var ( InvalidLoginParameterError = errors.New("InvalidLoginParameter") ) +// CanonicalRedirectURL constructs a redirect URL that points back to the application. +func CanonicalRedirectURL(r *http.Request) string { + redirectURL := "/" + referer, err := url.Parse(r.Referer()) + if err == nil && len(referer.Path) > 0 { + redirectURL = referer.Path + } + override := r.URL.Query().Get(RedirectURLParameter) + if len(override) > 0 { + referer, err = url.Parse(override) + if err == nil { + // strip scheme and host to avoid cross-domain redirects + referer.Scheme = "" + referer.Host = "" + redirectURL = referer.String() + } + } + return redirectURL +} + // LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found. // The value must exist in the supplied list of supported values. func LoginURLParameter(r *http.Request, parameter, fallback string, supported config.Supported) (string, error) { diff --git a/pkg/router/request_test.go b/pkg/request/request_test.go similarity index 65% rename from pkg/router/request_test.go rename to pkg/request/request_test.go index 20514dd..3ac3fd7 100644 --- a/pkg/router/request_test.go +++ b/pkg/request/request_test.go @@ -1,15 +1,32 @@ -package router_test +package request_test import ( - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/config" - "github.com/nais/wonderwall/pkg/router" + "github.com/nais/wonderwall/pkg/request" + "github.com/stretchr/testify/assert" + "net/http" + "net/url" + "testing" ) +func TestCanonicalRedirectURL(t *testing.T) { + r, err := http.NewRequest("GET", "http://localhost:8080/oauth2/login", nil) + assert.NoError(t, err) + + // Default URL is / + assert.Equal(t, "/", request.CanonicalRedirectURL(r)) + + // HTTP Referer header is 2nd priority + r.Header.Set("referer", "http://localhost:8080/foo/bar/baz?gnu=notunix") + assert.Equal(t, "/foo/bar/baz", request.CanonicalRedirectURL(r)) + + // If redirect parameter is set, use that + v := &url.Values{} + v.Set("redirect", "https://google.com/path/to/redirect?val1=foo&val2=bar") + r.URL.RawQuery = v.Encode() + assert.Equal(t, "/path/to/redirect?val1=foo&val2=bar", request.CanonicalRedirectURL(r)) +} + func TestLoginURLParameter(t *testing.T) { for _, test := range []struct { name string @@ -38,19 +55,19 @@ func TestLoginURLParameter(t *testing.T) { { name: "invalid URL parameter value should return error", url: "http://localhost:8080/oauth2/login?param=invalid", - expectErr: router.InvalidLoginParameterError, + expectErr: request.InvalidLoginParameterError, }, { name: "invalid fallback value should return error", fallback: "invalid", url: "http://localhost:8080/oauth2/login", - expectErr: router.InvalidLoginParameterError, + expectErr: request.InvalidLoginParameterError, }, { name: "no supported values should return error", url: "http://localhost:8080/oauth2/login", supported: config.Supported{""}, - expectErr: router.InvalidLoginParameterError, + expectErr: request.InvalidLoginParameterError, }, } { t.Run(test.name, func(t *testing.T) { @@ -74,7 +91,7 @@ func TestLoginURLParameter(t *testing.T) { supported = test.supported } - val, err := router.LoginURLParameter(r, parameter, fallback, supported) + val, err := request.LoginURLParameter(r, parameter, fallback, supported) if test.expectErr == nil { assert.NoError(t, err) diff --git a/pkg/router/handler.go b/pkg/router/handler.go index efcaf03..ef3def0 100644 --- a/pkg/router/handler.go +++ b/pkg/router/handler.go @@ -12,12 +12,6 @@ import ( "github.com/nais/wonderwall/pkg/session" ) -const ( - SecurityLevelURLParameter = "level" - LocaleURLParameter = "locale" - PostLogoutRedirectURIParameter = "post_logout_redirect_uri" -) - type Handler struct { Config config.IDPorten Crypter cryptutil.Crypter diff --git a/pkg/router/handler_login.go b/pkg/router/handler_login.go index 1b2d7b6..91b4d9e 100644 --- a/pkg/router/handler_login.go +++ b/pkg/router/handler_login.go @@ -3,7 +3,7 @@ package router import ( "errors" "fmt" - "github.com/nais/wonderwall/pkg/url" + "github.com/nais/wonderwall/pkg/request" "net/http" "github.com/nais/wonderwall/pkg/auth" @@ -34,7 +34,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { State: params.State, Nonce: params.Nonce, CodeVerifier: params.CodeVerifier, - Referer: url.CanonicalRedirectURL(r), + Referer: request.CanonicalRedirectURL(r), }) if err != nil { errorhandler.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err)) diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index 1e5d6c3..5feeaf7 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -2,6 +2,7 @@ package router import ( "fmt" + "github.com/nais/wonderwall/pkg/request" "net/http" "net/url" @@ -31,7 +32,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { h.deleteCookie(w, h.GetSessionCookieName()) v := u.Query() - v.Add("post_logout_redirect_uri", PostLogoutRedirectURI(r, h.Config.PostLogoutRedirectURI)) + v.Add("post_logout_redirect_uri", request.PostLogoutRedirectURI(r, h.Config.PostLogoutRedirectURI)) if len(idToken) != 0 { v.Add("id_token_hint", idToken) diff --git a/pkg/router/login_url.go b/pkg/router/login_url.go index 82e3245..19888d4 100644 --- a/pkg/router/login_url.go +++ b/pkg/router/login_url.go @@ -3,6 +3,7 @@ package router import ( "errors" "fmt" + "github.com/nais/wonderwall/pkg/request" "net/http" "net/url" @@ -55,7 +56,7 @@ func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error { fallback := h.Config.SecurityLevel.Value supported := h.Config.WellKnown.ACRValuesSupported - securityLevel, err := LoginURLParameter(r, SecurityLevelURLParameter, fallback, supported) + securityLevel, err := request.LoginURLParameter(r, request.SecurityLevelURLParameter, fallback, supported) if err != nil { return err } @@ -72,7 +73,7 @@ func (h *Handler) withLocale(r *http.Request, v url.Values) error { fallback := h.Config.Locale.Value supported := h.Config.WellKnown.UILocalesSupported - locale, err := LoginURLParameter(r, LocaleURLParameter, fallback, supported) + locale, err := request.LoginURLParameter(r, request.LocaleURLParameter, fallback, supported) if err != nil { return err } diff --git a/pkg/url/url.go b/pkg/url/url.go deleted file mode 100644 index 44a73f1..0000000 --- a/pkg/url/url.go +++ /dev/null @@ -1,28 +0,0 @@ -package url - -import ( - "net/http" - "net/url" -) - -const RedirectURLParameter = "redirect" - -// CanonicalRedirectURL constructs a redirect URL that points back to the application. -func CanonicalRedirectURL(r *http.Request) string { - redirectURL := "/" - referer, err := url.Parse(r.Referer()) - if err == nil && len(referer.Path) > 0 { - redirectURL = referer.Path - } - override := r.URL.Query().Get(RedirectURLParameter) - if len(override) > 0 { - referer, err = url.Parse(override) - if err == nil { - // strip scheme and host to avoid cross-domain redirects - referer.Scheme = "" - referer.Host = "" - redirectURL = referer.String() - } - } - return redirectURL -} diff --git a/pkg/url/url_test.go b/pkg/url/url_test.go deleted file mode 100644 index 885ee89..0000000 --- a/pkg/url/url_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package url - -import ( - "github.com/stretchr/testify/assert" - "net/http" - "net/url" - "testing" -) - -func TestCanonicalRedirectURL(t *testing.T) { - r, err := http.NewRequest("GET", "http://localhost:8080/oauth2/login", nil) - assert.NoError(t, err) - - // Default URL is / - assert.Equal(t, "/", CanonicalRedirectURL(r)) - - // HTTP Referer header is 2nd priority - r.Header.Set("referer", "http://localhost:8080/foo/bar/baz?gnu=notunix") - assert.Equal(t, "/foo/bar/baz", CanonicalRedirectURL(r)) - - // If redirect parameter is set, use that - v := &url.Values{} - v.Set("redirect", "https://google.com/path/to/redirect?val1=foo&val2=bar") - r.URL.RawQuery = v.Encode() - assert.Equal(t, "/path/to/redirect?val1=foo&val2=bar", CanonicalRedirectURL(r)) -}