diff --git a/pkg/url/redirect.go b/pkg/url/redirect.go index 4444e1f..0b9ee9d 100644 --- a/pkg/url/redirect.go +++ b/pkg/url/redirect.go @@ -21,13 +21,13 @@ var _ Redirect = &StandaloneRedirect{} type StandaloneRedirect struct { ingresses *ingress.Ingresses - validator *Validator + *cleaner } func NewStandaloneRedirect(ingresses *ingress.Ingresses) *StandaloneRedirect { return &StandaloneRedirect{ ingresses: ingresses, - validator: NewValidator(Relative, ingresses.Hosts()), + cleaner: newRelativeCleaner(ingresses.Hosts()), } } @@ -46,11 +46,7 @@ func (h *StandaloneRedirect) Canonical(r *http.Request) string { } func (h *StandaloneRedirect) Clean(r *http.Request, target string) string { - if h.validator.IsValidRedirect(r, target) { - return target - } - - return fallback(r, target, h.FallbackRedirect(r)).String() + return h.clean(r, target, h.FallbackRedirect(r)) } func (h *StandaloneRedirect) FallbackRedirect(r *http.Request) *url.URL { @@ -61,7 +57,7 @@ var _ Redirect = &SSOServerRedirect{} type SSOServerRedirect struct { fallbackRedirect *url.URL - validator *Validator + *cleaner } func NewSSOServerRedirect(config *config.Config) (*SSOServerRedirect, error) { @@ -72,7 +68,7 @@ func NewSSOServerRedirect(config *config.Config) (*SSOServerRedirect, error) { return &SSOServerRedirect{ fallbackRedirect: u, - validator: NewValidator(Absolute, []string{config.SSO.Domain}), + cleaner: newAbsoluteCleaner([]string{config.SSO.Domain}), }, nil } @@ -87,24 +83,20 @@ func (h *SSOServerRedirect) Canonical(r *http.Request) string { } func (h *SSOServerRedirect) Clean(r *http.Request, target string) string { - if h.validator.IsValidRedirect(r, target) { - return target - } - - return fallback(r, target, h.fallbackRedirect).String() + return h.clean(r, target, h.fallbackRedirect) } var _ Redirect = &SSOProxyRedirect{} type SSOProxyRedirect struct { fallbackRedirect *url.URL - validator *Validator + *cleaner } func NewSSOProxyRedirect(ingresses *ingress.Ingresses) *SSOProxyRedirect { return &SSOProxyRedirect{ fallbackRedirect: ingresses.Single().NewURL(), - validator: NewValidator(Absolute, ingresses.Hosts()), + cleaner: newAbsoluteCleaner(ingresses.Hosts()), } } @@ -130,11 +122,7 @@ func (h *SSOProxyRedirect) Canonical(r *http.Request) string { } func (h *SSOProxyRedirect) Clean(r *http.Request, target string) string { - if h.validator.IsValidRedirect(r, target) { - return target - } - - return fallback(r, target, h.getFallbackRedirect()).String() + return h.clean(r, target, h.getFallbackRedirect()) } // getFallbackRedirect returns a copy of the configured fallbackRedirect @@ -143,6 +131,30 @@ func (h *SSOProxyRedirect) getFallbackRedirect() *url.URL { return &u } +type cleaner struct { + Validator +} + +func newAbsoluteCleaner(allowedHosts []string) *cleaner { + return &cleaner{ + Validator: NewAbsoluteValidator(allowedHosts), + } +} + +func newRelativeCleaner(allowedHosts []string) *cleaner { + return &cleaner{ + Validator: NewRelativeValidator(allowedHosts), + } +} + +func (c *cleaner) clean(r *http.Request, target string, fallbackTarget *url.URL) string { + if c.IsValidRedirect(r, target) { + return target + } + + return fallback(r, target, fallbackTarget).String() +} + func redirectQueryParam(r *http.Request) string { return r.URL.Query().Get(RedirectQueryParameter) } diff --git a/pkg/url/validator.go b/pkg/url/validator.go index 27d4bcb..81e81d7 100644 --- a/pkg/url/validator.go +++ b/pkg/url/validator.go @@ -13,53 +13,24 @@ import ( // Matches //, /\ and both of these with whitespace in between (eg / / or / \). var invalidRedirectRegex = regexp.MustCompile(`[/\\](?:[\s\v]*|\.{1,2})[/\\]`) -type Validator struct { +var _ Validator = &AbsoluteValidator{} + +type Validator interface { + IsValidRedirect(r *http.Request, redirect string) bool +} + +type AbsoluteValidator struct { allowedDomains []string - urlType Type } -type Type int - -const ( - Relative Type = iota - Absolute -) - -func NewValidator(urlType Type, allowedDomains []string) *Validator { - return &Validator{urlType: urlType, allowedDomains: allowedDomains} +func NewAbsoluteValidator(allowedDomains []string) *AbsoluteValidator { + return &AbsoluteValidator{allowedDomains: allowedDomains} } -func (v *Validator) IsValidRedirect(r *http.Request, redirect string) bool { - switch v.urlType { - case Absolute: - return v.isValidAbsoluteRedirect(r, redirect) - case Relative: - return v.isValidRelativeRedirect(r, redirect) - default: - return v.isValidAbsoluteRedirect(r, redirect) - } -} - -// isValidRelativeRedirect validates that the given redirect string is a valid relative URL. -// It must be an absolute path (i.e. has a leading '/'). -func (v *Validator) isValidRelativeRedirect(r *http.Request, redirect string) bool { - u, ok := parsableRequestURI(r, redirect) - if !ok { - return false - } - - if isRelativeURL(u) && isValidAbsolutePath(u.String()) { - return true - } - - mw.LogEntryFrom(r).Infof("validator: not a valid relative URL") - return false -} - -// isValidAbsoluteRedirect validates that the given redirect string is a valid absolute URL. +// IsValidRedirect validates that the given redirect string is a valid absolute URL. // It must use the 'http' or 'https' scheme. // It must point to a host that matches the configured list of allowed domains. -func (v *Validator) isValidAbsoluteRedirect(r *http.Request, redirect string) bool { +func (v *AbsoluteValidator) IsValidRedirect(r *http.Request, redirect string) bool { u, ok := parsableRequestURI(r, redirect) if !ok { return false @@ -87,6 +58,32 @@ func (v *Validator) isValidAbsoluteRedirect(r *http.Request, redirect string) bo return false } +var _ Validator = &RelativeValidator{} + +type RelativeValidator struct { + allowedDomains []string +} + +func NewRelativeValidator(allowedDomains []string) *RelativeValidator { + return &RelativeValidator{allowedDomains: allowedDomains} +} + +// IsValidRedirect validates that the given redirect string is a valid relative URL. +// It must be an absolute path (i.e. has a leading '/'). +func (v *RelativeValidator) IsValidRedirect(r *http.Request, redirect string) bool { + u, ok := parsableRequestURI(r, redirect) + if !ok { + return false + } + + if isRelativeURL(u) && isValidAbsolutePath(u.String()) { + return true + } + + mw.LogEntryFrom(r).Infof("validator: not a valid relative URL") + return false +} + func parsableRequestURI(r *http.Request, redirect string) (*url.URL, bool) { if redirect == "" { mw.LogEntryFrom(r).Infof("validator: redirect is empty") diff --git a/pkg/url/validator_test.go b/pkg/url/validator_test.go index ec7dd25..c238f65 100644 --- a/pkg/url/validator_test.go +++ b/pkg/url/validator_test.go @@ -14,7 +14,7 @@ import ( urlpkg "github.com/nais/wonderwall/pkg/url" ) -func TestValidator_IsValidRedirect(t *testing.T) { +func TestAbsoluteValidator_IsValidRedirect(t *testing.T) { cfg := mock.Config() cfg.SSO.Domain = "wonderwall" ingresses := mock.Ingresses(cfg) @@ -24,8 +24,7 @@ func TestValidator_IsValidRedirect(t *testing.T) { cfg.SSO.Domain, "www.whitelisteddomain.tld", } - absoluteValidator := urlpkg.NewValidator(urlpkg.Absolute, allowedDomains) - relativeValidator := urlpkg.NewValidator(urlpkg.Relative, allowedDomains) + absoluteValidator := urlpkg.NewAbsoluteValidator(allowedDomains) t.Run("open redirects list", func(t *testing.T) { file, err := os.Open("testdata/open-redirects.txt") @@ -36,6 +35,123 @@ func TestValidator_IsValidRedirect(t *testing.T) { for scanner.Scan() { input := url.QueryEscape(scanner.Text()) assert.False(t, absoluteValidator.IsValidRedirect(r, input), fmt.Sprintf("%q should not pass validation", input)) + } + + err = scanner.Err() + require.NoError(t, err) + }) + + for _, tt := range []struct { + name string + redirectParam string + wantErr bool + }{ + { + name: "absolute url with parameters", + redirectParam: "https://wonderwall/path/to/redirect?val1=foo&val2=bar", + }, + { + name: "absolute url with http scheme", + redirectParam: "https://wonderwall/path/to/redirect?val1=foo&val2=bar", + }, + { + name: "absolute url with non-http scheme", + redirectParam: "ftp://wonderwall/path/to/redirect?val1=foo&val2=bar", + wantErr: true, + }, + { + name: "root url with trailing slash", + redirectParam: "https://wonderwall/", + }, + { + name: "root url without trailing slash", + redirectParam: "https://wonderwall", + }, + { + name: "url path with trailing slash", + redirectParam: "https://wonderwall/path/", + }, + { + name: "url path without trailing slash", + redirectParam: "https://wonderwall/path", + }, + { + name: "different domain", + redirectParam: "https://not-wonderwall/path/to/redirect?val1=foo&val2=bar", + wantErr: true, + }, + { + name: "absolute path", + redirectParam: "/path", + wantErr: true, + }, + { + name: "absolute path with query parameters", + redirectParam: "/path?gnu=notunix", + wantErr: true, + }, + { + name: "relative path", + redirectParam: "path", + wantErr: true, + }, + { + name: "relative path with query parameters", + redirectParam: "path?gnu=notunix", + wantErr: true, + }, + { + name: "double-url encoded path", + redirectParam: "%2Fpath", + wantErr: true, + }, + { + name: "double-url encoded path and query parameters", + redirectParam: "%2Fpath%3Fgnu%3Dnotunix", + wantErr: true, + }, + { + name: "double-url encoded url", + redirectParam: "http%3A%2F%2Flocalhost%3A8080%2Fpath", + wantErr: true, + }, + { + name: "double-url encoded url and multiple query parameters", + redirectParam: "http%3A%2F%2Flocalhost%3A8080%2Fpath%3Fgnu%3Dnotunix%26foo%3Dbar", + wantErr: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + actual := absoluteValidator.IsValidRedirect(r, tt.redirectParam) + if tt.wantErr { + assert.False(t, actual) + } else { + assert.True(t, actual) + } + }) + } +} + +func TestRelativeValidator_IsValidRedirect(t *testing.T) { + cfg := mock.Config() + cfg.SSO.Domain = "wonderwall" + ingresses := mock.Ingresses(cfg) + r := mock.NewGetRequest("https://wonderwall", ingresses) + + allowedDomains := []string{ + cfg.SSO.Domain, + "www.whitelisteddomain.tld", + } + relativeValidator := urlpkg.NewRelativeValidator(allowedDomains) + + t.Run("open redirects list", func(t *testing.T) { + file, err := os.Open("testdata/open-redirects.txt") + require.NoError(t, err) + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + input := url.QueryEscape(scanner.Text()) assert.False(t, relativeValidator.IsValidRedirect(r, input), fmt.Sprintf("%q should not pass validation", input)) } @@ -46,116 +162,93 @@ func TestValidator_IsValidRedirect(t *testing.T) { for _, tt := range []struct { name string redirectParam string - urlType urlpkg.Type - wantErr bool // expect error regardless of validator type + wantErr bool }{ { name: "absolute url with parameters", redirectParam: "https://wonderwall/path/to/redirect?val1=foo&val2=bar", - urlType: urlpkg.Absolute, + wantErr: true, }, { name: "absolute url with http scheme", redirectParam: "https://wonderwall/path/to/redirect?val1=foo&val2=bar", - urlType: urlpkg.Absolute, + wantErr: true, }, { name: "absolute url with non-http scheme", redirectParam: "ftp://wonderwall/path/to/redirect?val1=foo&val2=bar", - urlType: urlpkg.Absolute, wantErr: true, }, { name: "root url with trailing slash", redirectParam: "https://wonderwall/", - urlType: urlpkg.Absolute, + wantErr: true, }, { name: "root url without trailing slash", redirectParam: "https://wonderwall", - urlType: urlpkg.Absolute, + wantErr: true, }, { name: "url path with trailing slash", redirectParam: "https://wonderwall/path/", - urlType: urlpkg.Absolute, + wantErr: true, }, { name: "url path without trailing slash", redirectParam: "https://wonderwall/path", - urlType: urlpkg.Absolute, + wantErr: true, }, { name: "different domain", redirectParam: "https://not-wonderwall/path/to/redirect?val1=foo&val2=bar", - urlType: urlpkg.Absolute, wantErr: true, }, { name: "absolute path", redirectParam: "/path", - urlType: urlpkg.Relative, }, { name: "absolute path with query parameters", redirectParam: "/path?gnu=notunix", - urlType: urlpkg.Relative, }, { name: "relative path", redirectParam: "path", - urlType: urlpkg.Relative, wantErr: true, }, { name: "relative path with query parameters", redirectParam: "path?gnu=notunix", - urlType: urlpkg.Relative, wantErr: true, }, { name: "double-url encoded path", redirectParam: "%2Fpath", - urlType: urlpkg.Relative, wantErr: true, }, { name: "double-url encoded path and query parameters", redirectParam: "%2Fpath%3Fgnu%3Dnotunix", - urlType: urlpkg.Relative, wantErr: true, }, { name: "double-url encoded url", redirectParam: "http%3A%2F%2Flocalhost%3A8080%2Fpath", - urlType: urlpkg.Relative, wantErr: true, }, { name: "double-url encoded url and multiple query parameters", redirectParam: "http%3A%2F%2Flocalhost%3A8080%2Fpath%3Fgnu%3Dnotunix%26foo%3Dbar", - urlType: urlpkg.Relative, wantErr: true, }, } { t.Run(tt.name, func(t *testing.T) { - switch tt.urlType { - case urlpkg.Relative: - actual := relativeValidator.IsValidRedirect(r, tt.redirectParam) - if tt.wantErr { - assert.False(t, actual) - } else { - assert.True(t, actual) - } - case urlpkg.Absolute: - actual := absoluteValidator.IsValidRedirect(r, tt.redirectParam) - if tt.wantErr { - assert.False(t, actual) - } else { - assert.True(t, actual) - } - default: - assert.FailNow(t, "invalid url type: %s", tt.urlType) + actual := relativeValidator.IsValidRedirect(r, tt.redirectParam) + if tt.wantErr { + assert.False(t, actual) + } else { + assert.True(t, actual) } }) }