From cc78d2195ba8ae2ecbfec554cfbb156066238f07 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 10 Mar 2022 11:03:01 +0100 Subject: [PATCH] fix: ensure canonical redirect URL is not empty --- pkg/router/request/request.go | 28 +++++++++++++------ pkg/router/request/request_test.go | 45 +++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/pkg/router/request/request.go b/pkg/router/request/request.go index dd4fd43..25caff9 100644 --- a/pkg/router/request/request.go +++ b/pkg/router/request/request.go @@ -18,25 +18,35 @@ var ( // CanonicalRedirectURL constructs a redirect URL that points back to the application. func CanonicalRedirectURL(r *http.Request) string { + // 1. default redirectURL := "/" + // 2. Referer header is set referer := RefererPath(r) if len(referer) > 0 { redirectURL = referer } + // 3. redirect parameter is set override := r.URL.Query().Get(RedirectURLParameter) - if len(override) > 0 { - referer, err := url.Parse(override) - if err == nil { - // strip scheme and host to avoid cross-domain redirects - referer.Scheme = "" - referer.Host = "" - redirectURL = referer.String() - } + if len(override) <= 0 { + return redirectURL } - return redirectURL + overrideUrl, err := url.Parse(override) + if err != nil { + return redirectURL + } + + // strip scheme and host to avoid cross-domain redirects + overrideUrl.Scheme = "" + overrideUrl.Host = "" + overrideString := overrideUrl.String() + if len(overrideString) <= 0 { + return "/" + } + + return overrideString } // LoginURLParameter attempts to get a given parameter from the given HTTP request, falling back if none found. diff --git a/pkg/router/request/request_test.go b/pkg/router/request/request_test.go index db5aa4e..032646e 100644 --- a/pkg/router/request/request_test.go +++ b/pkg/router/request/request_test.go @@ -23,10 +23,47 @@ func TestCanonicalRedirectURL(t *testing.T) { assert.Equal(t, "/foo/bar/baz", request.CanonicalRedirectURL(r)) // If redirect parameter is set, use that - v := &url.Values{} - v.Set("redirect", "https://google.com/path/to/redirect?val1=foo&val2=bar") - r.URL.RawQuery = v.Encode() - assert.Equal(t, "/path/to/redirect?val1=foo&val2=bar", request.CanonicalRedirectURL(r)) + t.Run("redirect parameter", func(t *testing.T) { + for _, test := range []struct { + name string + value string + expected string + }{ + { + name: "complete url with parameters", + value: "https://google.com/path/to/redirect?val1=foo&val2=bar", + expected: "/path/to/redirect?val1=foo&val2=bar", + }, + { + name: "root url with trailing slash", + value: "https://google.com/", + expected: "/", + }, + { + name: "root url without trailing slash", + value: "https://google.com", + expected: "/", + }, + { + name: "url path with trailing slash", + value: "https://google.com/path/", + expected: "/path/", + }, + { + name: "url path without trailing slash", + value: "https://google.com/path", + expected: "/path", + }, + } { + t.Run(test.name, func(t *testing.T) { + v := &url.Values{} + v.Set("redirect", test.value) + r.URL.RawQuery = v.Encode() + assert.Equal(t, test.expected, request.CanonicalRedirectURL(r)) + }) + } + }) + } func TestLoginURLParameter(t *testing.T) {