diff --git a/pkg/handler/error/error_test.go b/pkg/handler/error/error_test.go index 1830d1e..003a794 100644 --- a/pkg/handler/error/error_test.go +++ b/pkg/handler/error/error_test.go @@ -4,11 +4,11 @@ import ( "fmt" "net/http" "net/http/httptest" - "net/url" "testing" "github.com/stretchr/testify/assert" + urlpkg "github.com/nais/wonderwall/pkg/handler/url" "github.com/nais/wonderwall/pkg/ingress" mw "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/mock" @@ -78,12 +78,12 @@ func TestHandler_Retry(t *testing.T) { { name: "login path", request: httpRequest("/oauth2/login"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "callback path", request: httpRequest("/oauth2/callback"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "logout path", @@ -99,7 +99,7 @@ func TestHandler_Retry(t *testing.T) { name: "login with non-default ingress", request: httpRequest("/domene/oauth2/login"), ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=" + url.QueryEscape("/domene"), + want: "/domene/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/domene"), }, { name: "logout with non-default ingress", @@ -110,103 +110,103 @@ func TestHandler_Retry(t *testing.T) { { name: "login with referer", request: httpRequest("/oauth2/login", "/api/me"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/api/me"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/api/me"), }, { name: "login with referer on non-default ingress", request: httpRequest("/domene/oauth2/login", "/api/me"), ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=" + url.QueryEscape("/api/me"), + want: "/domene/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/api/me"), }, { name: "login with root referer", request: httpRequest("/oauth2/login", "/"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with root referer on non-default ingress", request: httpRequest("/domene/oauth2/login", "/"), ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/domene/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with cookie referer", request: httpRequest("/oauth2/login"), loginCookie: &openid.LoginCookie{Referer: "/"}, - want: "/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with empty cookie referer", request: httpRequest("/oauth2/login"), loginCookie: &openid.LoginCookie{Referer: ""}, - want: "/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with cookie referer takes precedence over referer header", request: httpRequest("/oauth2/login", "/api/me"), loginCookie: &openid.LoginCookie{Referer: "/api/headers"}, - want: "/oauth2/login?redirect=" + url.QueryEscape("/api/headers"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/api/headers"), }, { name: "login with cookie referer on non-default ingress", request: httpRequest("/domene/oauth2/login"), loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"}, ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=" + url.QueryEscape("/domene/api/me"), + want: "/domene/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/domene/api/me"), }, { name: "login with redirect parameter set", request: httpRequest("/oauth2/login?redirect=/api/me"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/api/me"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/api/me"), }, { name: "login with redirect parameter set and query parameters", request: httpRequest("/oauth2/login?redirect=/api/me?a=b%26c=d"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/api/me?a=b&c=d"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/api/me?a=b&c=d"), }, { name: "login with redirect parameter set on non-default ingress", request: httpRequest("/domene/oauth2/login?redirect=/api/me"), ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=" + url.QueryEscape("/api/me"), + want: "/domene/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/api/me"), }, { name: "login with redirect parameter set takes precedence over referer header", request: httpRequest("/oauth2/login?redirect=/other", "/api/me"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/other"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/other"), }, { name: "login with redirect parameter set to relative root takes precedence over referer header", request: httpRequest("/oauth2/login?redirect=/", "/api/me"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with redirect parameter set to relative root on non-default ingress takes precedence over referer header", request: httpRequest("/domene/oauth2/login?redirect=/", "/api/me"), ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/domene/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with redirect parameter set to absolute url takes precedence over referer header", request: httpRequest("/oauth2/login?redirect=http://localhost:8080", "/api/me"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with redirect parameter set to absolute url with trailing slash takes precedence over referer header", request: httpRequest("/oauth2/login?redirect=http://localhost:8080/", "/api/me"), - want: "/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with redirect parameter set to absolute url on non-default ingress takes precedence over referer header", request: httpRequest("/domene/oauth2/login?redirect=http://localhost:8080/", "/api/me"), ingress: "https://test.nav.no/domene", - want: "/domene/oauth2/login?redirect=" + url.QueryEscape("/"), + want: "/domene/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/"), }, { name: "login with cookie referer takes precedence over redirect parameter", request: httpRequest("/oauth2/login?redirect=/other"), loginCookie: &openid.LoginCookie{Referer: "/domene/api/me"}, - want: "/oauth2/login?redirect=" + url.QueryEscape("/domene/api/me"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/domene/api/me"), }, } { t.Run(test.name, func(t *testing.T) { diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index e7881df..6f5e432 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -375,7 +375,7 @@ func TestHandler_Default(t *testing.T) { // redirect should point to local login endpoint loginLocation := resp.Location - assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect="+url.QueryEscape("/"), loginLocation.String()) + assert.Equal(t, idp.RelyingPartyServer.URL+"/oauth2/login?redirect-encoded="+urlpkg.RedirectEncoded("/"), loginLocation.String()) // follow redirect to local login endpoint resp = get(t, rpClient, loginLocation.String()) diff --git a/pkg/handler/url/url.go b/pkg/handler/url/url.go index b7ae8f6..5a6693a 100644 --- a/pkg/handler/url/url.go +++ b/pkg/handler/url/url.go @@ -1,6 +1,7 @@ package url import ( + "encoding/base64" "fmt" "net/http" "net/url" @@ -11,7 +12,8 @@ import ( ) const ( - RedirectURLParameter = "redirect" + RedirectURLParameter = "redirect" + RedirectURLEncodedParameter = "redirect-encoded" ) // CanonicalRedirect constructs a redirect URL that points back to the application. @@ -36,6 +38,21 @@ func CanonicalRedirect(r *http.Request) string { redirect = redirectParam } + // 4. Redirect-encoded parameter is set + redirectEncodedParam := r.URL.Query().Get(RedirectURLEncodedParameter) + if len(redirectEncodedParam) > 0 { + decodedBytes, err := base64.RawURLEncoding.DecodeString(redirectEncodedParam) + if err == nil { + redirect = string(decodedBytes) + } + } + + // Ensure URL isn't encoded + redirect, err := url.QueryUnescape(redirect) + if err != nil { + return ingressPath + } + parsed, err := url.Parse(redirect) if err != nil { // Silently fall back to ingress path @@ -48,12 +65,6 @@ func CanonicalRedirect(r *http.Request) string { redirect = parsed.String() - // Ensure URL isn't encoded - redirect, err = url.QueryUnescape(redirect) - if err != nil { - return ingressPath - } - // Root path without trailing slash is empty if len(parsed.Path) == 0 { redirect = "/" @@ -72,12 +83,16 @@ func LoginURL(prefix, redirectTarget string) string { u.Path = path.Join(prefix, paths.OAuth2, paths.Login) v := url.Values{} - v.Set(RedirectURLParameter, redirectTarget) + v.Set(RedirectURLEncodedParameter, RedirectEncoded(redirectTarget)) u.RawQuery = v.Encode() return u.String() } +func RedirectEncoded(s string) string { + return base64.RawURLEncoding.EncodeToString([]byte(s)) +} + func LoginCallbackURL(r *http.Request) (string, error) { return makeCallbackURL(r, paths.LoginCallback) } diff --git a/pkg/handler/url/url_test.go b/pkg/handler/url/url_test.go index 32b0345..35c7be2 100644 --- a/pkg/handler/url/url_test.go +++ b/pkg/handler/url/url_test.go @@ -95,7 +95,7 @@ func TestCanonicalRedirect(t *testing.T) { } }) - // If redirect parameter is set, use that + // If either redirect or redirect-encoded parameter is set, use that t.Run("redirect parameter is set", func(t *testing.T) { for _, test := range []struct { name string @@ -150,12 +150,12 @@ func TestCanonicalRedirect(t *testing.T) { { name: "url encoded url", value: "http%3A%2F%2Flocalhost%3A8080%2Fpath", - expected: "http://localhost:8080/path", + expected: "/path", }, { name: "url encoded url and multiple query parameters", value: "http%3A%2F%2Flocalhost%3A8080%2Fpath%3Fgnu%3Dnotunix%26foo%3Dbar", - expected: "http://localhost:8080/path?gnu=notunix&foo=bar", + expected: "/path?gnu=notunix&foo=bar", }, } { t.Run(test.name, func(t *testing.T) { @@ -164,6 +164,13 @@ func TestCanonicalRedirect(t *testing.T) { r.URL.RawQuery = v.Encode() assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r)) }) + + t.Run(test.name+" encoded", func(t *testing.T) { + v := &url.Values{} + v.Set("redirect-encoded", urlpkg.RedirectEncoded(test.value)) + r.URL.RawQuery = v.Encode() + assert.Equal(t, test.expected, urlpkg.CanonicalRedirect(r)) + }) } }) } @@ -179,25 +186,25 @@ func TestLoginURL(t *testing.T) { name: "no prefix", prefix: "", redirectTarget: "https://test.example.com?some=param&other=param2", - want: "/oauth2/login?redirect=" + url.QueryEscape("https://test.example.com?some=param&other=param2"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("https://test.example.com?some=param&other=param2"), }, { name: "with prefix", prefix: "/path", redirectTarget: "https://test.example.com?some=param&other=param2", - want: "/path/oauth2/login?redirect=" + url.QueryEscape("https://test.example.com?some=param&other=param2"), + want: "/path/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("https://test.example.com?some=param&other=param2"), }, { name: "we need to go deeper", prefix: "/deeper/path", redirectTarget: "https://test.example.com?some=param&other=param2", - want: "/deeper/path/oauth2/login?redirect=" + url.QueryEscape("https://test.example.com?some=param&other=param2"), + want: "/deeper/path/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("https://test.example.com?some=param&other=param2"), }, { name: "relative target", prefix: "", redirectTarget: "/path?some=param&other=param2", - want: "/oauth2/login?redirect=" + url.QueryEscape("/path?some=param&other=param2"), + want: "/oauth2/login?redirect-encoded=" + urlpkg.RedirectEncoded("/path?some=param&other=param2"), }, } { t.Run(test.name, func(t *testing.T) {