diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 7121ea2..74fa0b7 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -39,7 +39,6 @@ type Standalone struct { CookieOptions cookie.Options Crypter crypto.Crypter Ingresses *ingress.Ingresses - OpenidConfig openidconfig.Config Redirect url.Redirect SessionManager session.Manager UpstreamProxy *ReverseProxy @@ -86,7 +85,6 @@ func NewStandalone( CookieOptions: cookieOpts, Crypter: crypter, Ingresses: ingresses, - OpenidConfig: openidConfig, Redirect: url.NewStandaloneRedirect(ingresses), SessionManager: sessionManager, UpstreamProxy: NewReverseProxy(upstream, true), @@ -255,7 +253,21 @@ func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, globalLogout cookie.Clear(w, cookie.Session, s.GetCookieOptions(r)) if globalLogout { - logger.Debug("logout: redirecting to identity provider for global/single-logout") + // only set a canonical redirect if it was provided in the request as a query parameter + canonicalRedirect := r.URL.Query().Get(url.RedirectQueryParameter) + if canonicalRedirect != "" { + canonicalRedirect = s.Redirect.Canonical(r) + } + + opts := s.CookieOptions.WithExpiresIn(5 * time.Minute) + err = logout.SetCookie(w, opts, s.Crypter, canonicalRedirect) + if err != nil { + s.InternalError(w, r, fmt.Errorf("logout: setting logout cookie: %w", err)) + return + } + + logger.WithField("redirect_after_logout", canonicalRedirect). + Info("logout: redirecting to identity provider for global/single-logout") metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated) http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusFound) } else { @@ -266,10 +278,19 @@ func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, globalLogout } func (s *Standalone) LogoutCallback(w http.ResponseWriter, r *http.Request) { - redirect := s.Client.LogoutCallback(r).PostLogoutRedirectURI() + logger := mw.LogEntryFrom(r) + cookie.Clear(w, cookie.Logout, s.CookieOptions) + + logoutCookie, err := openid.GetLogoutCookie(r, s.Crypter) + if err != nil { + logger.Debugf("logout/callback: getting cookie: %+v; ignoring...", err) + } + + logoutCallback := s.Client.LogoutCallback(r, logoutCookie, s.Redirect.GetValidator()) + redirect := logoutCallback.PostLogoutRedirectURI() cookie.Clear(w, cookie.Retry, s.GetCookieOptions(r)) - mw.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect) + logger.Infof("logout/callback: redirecting to %q", redirect) http.Redirect(w, r, redirect, http.StatusFound) } diff --git a/pkg/handler/handler_sso_proxy.go b/pkg/handler/handler_sso_proxy.go index 859db74..e10e3b6 100644 --- a/pkg/handler/handler_sso_proxy.go +++ b/pkg/handler/handler_sso_proxy.go @@ -141,8 +141,21 @@ func (s *SSOProxy) LoginCallback(w http.ResponseWriter, r *http.Request) { } func (s *SSOProxy) Logout(w http.ResponseWriter, r *http.Request) { - target := s.GetSSOServerURL().JoinPath(paths.OAuth2, paths.Logout) - http.Redirect(w, r, target.String(), http.StatusFound) + target := s.GetSSOServerURL() + + // only set a canonical redirect if it was provided in the request as a query parameter + canonicalRedirect := r.URL.Query().Get(url.RedirectQueryParameter) + if canonicalRedirect != "" { + canonicalRedirect = s.Redirect.Canonical(r) + } + ssoServerLogoutURL := url.Logout(target, canonicalRedirect) + + logentry.LogEntryFrom(r).WithFields(log.Fields{ + "redirect_to": ssoServerLogoutURL, + "redirect_after_logout": canonicalRedirect, + }).Info("logout: redirecting to sso server") + + http.Redirect(w, r, ssoServerLogoutURL, http.StatusFound) } func (s *SSOProxy) LogoutCallback(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/handler/handler_test.go b/pkg/handler/handler_test.go index 4dad1f1..048413e 100644 --- a/pkg/handler/handler_test.go +++ b/pkg/handler/handler_test.go @@ -90,24 +90,7 @@ func TestLogout(t *testing.T) { rpClient := idp.RelyingPartyClient() login(t, rpClient, idp) - - resp := selfInitiatedLogout(t, rpClient, idp) - - // Get endsession endpoint after local logout - endsessionURL := resp.Location - - idpserverURL, err := url.Parse(idp.ProviderServer.URL) - assert.NoError(t, err) - - req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback") - expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req) - assert.NoError(t, err) - - endsessionParams := endsessionURL.Query() - assert.Equal(t, idpserverURL.Host, endsessionURL.Host) - assert.Equal(t, "/endsession", endsessionURL.Path) - assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"]) - assert.NotEmpty(t, endsessionParams["id_token_hint"]) + selfInitiatedLogout(t, rpClient, idp) } func TestLogoutLocal(t *testing.T) { @@ -131,6 +114,18 @@ func TestLogoutCallback(t *testing.T) { logout(t, rpClient, idp) } +func TestLogoutCallback_WithRedirect(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + redirect := idp.RelyingPartyServer.URL + "/api/me" + + rpClient := idp.RelyingPartyClient() + login(t, rpClient, idp) + logout(t, rpClient, idp, redirect) +} + func TestFrontChannelLogout(t *testing.T) { cfg := mock.Config() idp := mock.NewIdentityProvider(cfg) @@ -485,11 +480,17 @@ func login(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) *htt return callback(t, rpClient, resp) } -func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response { +func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider, redirectAfterLogout ...string) response { // Request self-initiated logout logoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout") assert.NoError(t, err) + if len(redirectAfterLogout) > 0 { + v := url.Values{} + v.Set(urlpkg.RedirectQueryParameter, redirectAfterLogout[0]) + logoutURL.RawQuery = v.Encode() + } + resp := get(t, rpClient, logoutURL.String()) assert.Equal(t, http.StatusFound, resp.StatusCode) @@ -498,21 +499,44 @@ func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.Identity assert.Nil(t, sessionCookie) + // Get endsession endpoint after local logout + endsessionURL := resp.Location + + idpserverURL, err := url.Parse(idp.ProviderServer.URL) + assert.NoError(t, err) + + req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout") + expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req) + assert.NoError(t, err) + + endsessionParams := endsessionURL.Query() + assert.Equal(t, idpserverURL.Host, endsessionURL.Host) + assert.Equal(t, "/endsession", endsessionURL.Path) + assert.Equal(t, expectedLogoutCallbackURL, endsessionParams.Get("post_logout_redirect_uri")) + assert.NotEmpty(t, endsessionParams.Get("id_token_hint")) + assert.NotEmpty(t, endsessionParams.Get("state")) + return resp } -func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) { +func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider, redirectAfterLogout ...string) { // Get endsession endpoint after local logout - resp := selfInitiatedLogout(t, rpClient, idp) + resp := selfInitiatedLogout(t, rpClient, idp, redirectAfterLogout...) + expectedState := resp.Location.Query().Get("state") // Follow redirect to endsession endpoint at identity provider resp = get(t, rpClient, resp.Location.String()) assert.Equal(t, http.StatusFound, resp.StatusCode) - // Get post-logout redirect URI after successful logout at identity provider logoutCallbackURI := resp.Location - req := idp.GetRequest(resp.Location.String()) + // Assert state for callback equals state sent in initial logout request + actualState := logoutCallbackURI.Query().Get("state") + assert.NotEmpty(t, actualState) + assert.Equal(t, expectedState, actualState) + + // Assert post-logout redirect URI after successful logout at identity provider + req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout") expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req) assert.NoError(t, err) @@ -521,10 +545,15 @@ func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) { // Follow redirect back to logout callback resp = get(t, rpClient, logoutCallbackURI.String()) - assert.Equal(t, http.StatusFound, resp.StatusCode) // Get post-logout redirect URI after redirect back to logout callback - assert.Equal(t, "https://google.com", resp.Location.String()) + assert.Equal(t, http.StatusFound, resp.StatusCode) + + expectedRedirect := "https://google.com" + if len(redirectAfterLogout) > 0 { + expectedRedirect = redirectAfterLogout[0] + } + assert.Equal(t, expectedRedirect, resp.Location.String()) cookies := rpClient.Jar.Cookies(logoutCallbackURI) sessionCookie := getCookieFromJar(cookie.Session, cookies) diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index 1c383fc..99b6e0e 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -558,10 +558,17 @@ func (ip *IdentityProviderHandler) validateClientAuthentication(w http.ResponseW func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() postLogoutRedirectURI := query.Get("post_logout_redirect_uri") + state := query.Get("state") if postLogoutRedirectURI == "" { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("missing required parameters")) + w.Write([]byte("missing required 'post_logout_redirect_uri' parameter")) + return + } + + if state == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing required 'state' parameter")) return } @@ -572,6 +579,10 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req return } + v := url.Values{} + v.Set("state", state) + u.RawQuery = v.Encode() + http.Redirect(w, r, u.String(), http.StatusFound) } diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index de78527..dc98b5d 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -18,6 +18,7 @@ import ( "github.com/nais/wonderwall/pkg/openid" openidconfig "github.com/nais/wonderwall/pkg/openid/config" + urlpkg "github.com/nais/wonderwall/pkg/url" ) var ( @@ -91,8 +92,8 @@ func (c *Client) Logout(r *http.Request) (*Logout, error) { return logout, nil } -func (c *Client) LogoutCallback(r *http.Request) *LogoutCallback { - return NewLogoutCallback(c, r) +func (c *Client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie, validator urlpkg.Validator) *LogoutCallback { + return NewLogoutCallback(c, r, cookie, validator) } func (c *Client) LogoutFrontchannel(r *http.Request) *LogoutFrontchannel { @@ -182,3 +183,15 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid return &tokenResponse, nil } + +func StateMismatchError(expectedState, actualState string) error { + if len(actualState) <= 0 { + return fmt.Errorf("missing state parameter in request (possible csrf)") + } + + if expectedState != actualState { + return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState) + } + + return nil +} diff --git a/pkg/openid/client/client_test.go b/pkg/openid/client/client_test.go index b6828e3..15c3ec8 100644 --- a/pkg/openid/client/client_test.go +++ b/pkg/openid/client/client_test.go @@ -45,6 +45,22 @@ func TestMakeAssertion(t *testing.T) { assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry))) } +func TestStateMismatchError(t *testing.T) { + for _, tt := range []struct { + name, expected, actual string + assertion assert.ErrorAssertionFunc + }{ + {"missing actual state", "expected", "", assert.Error}, + {"state mismatch", "match", "not-match", assert.Error}, + {"state match", "match", "match", assert.NoError}, + } { + t.Run(tt.name, func(t *testing.T) { + err := client.StateMismatchError(tt.expected, tt.actual) + tt.assertion(t, err) + }) + } +} + func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client { jwksProvider := mock.NewTestJwksProvider() return client.NewClient(config, jwksProvider) diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index ca1d585..e6b80b5 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -15,7 +15,6 @@ import ( type LoginCallback struct { *Client cookie *openid.LoginCookie - request *http.Request requestParams url.Values } @@ -37,7 +36,6 @@ func NewLoginCallback(c *Client, r *http.Request, cookie *openid.LoginCookie) (* return &LoginCallback{ Client: c, cookie: cookie, - request: r, requestParams: r.URL.Query(), }, nil } @@ -56,15 +54,7 @@ func (in *LoginCallback) StateMismatchError() error { expectedState := in.cookie.State actualState := in.requestParams.Get(openid.State) - if len(actualState) <= 0 { - return fmt.Errorf("missing state parameter in request (possible csrf)") - } - - if expectedState != actualState { - return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState) - } - - return nil + return StateMismatchError(expectedState, actualState) } func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) { diff --git a/pkg/openid/client/logout.go b/pkg/openid/client/logout.go index 4c9fa88..b7b32ec 100644 --- a/pkg/openid/client/logout.go +++ b/pkg/openid/client/logout.go @@ -1,16 +1,20 @@ package client import ( + "encoding/json" "fmt" "net/http" + "github.com/nais/wonderwall/pkg/cookie" + "github.com/nais/wonderwall/pkg/crypto" "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/strings" urlpkg "github.com/nais/wonderwall/pkg/url" ) type Logout struct { *Client - request *http.Request + Cookie *openid.LogoutCookie logoutCallbackURL string } @@ -20,10 +24,19 @@ func NewLogout(c *Client, r *http.Request) (*Logout, error) { return nil, fmt.Errorf("generating logout callback url: %w", err) } + state, err := strings.GenerateBase64(32) + if err != nil { + return nil, fmt.Errorf("generating state: %w", err) + } + + logoutCookie := &openid.LogoutCookie{ + State: state, + } + return &Logout{ Client: c, + Cookie: logoutCookie, logoutCallbackURL: logoutCallbackURL, - request: r, }, nil } @@ -31,6 +44,7 @@ func (in *Logout) SingleLogoutURL(idToken string) string { endSessionEndpoint := in.cfg.Provider().EndSessionEndpointURL() v := endSessionEndpoint.Query() v.Add(openid.PostLogoutRedirectURI, in.logoutCallbackURL) + v.Add(openid.State, in.Cookie.State) if len(idToken) > 0 { v.Add(openid.IDTokenHint, idToken) @@ -39,3 +53,15 @@ func (in *Logout) SingleLogoutURL(idToken string) string { endSessionEndpoint.RawQuery = v.Encode() return endSessionEndpoint.String() } + +func (in *Logout) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error { + in.Cookie.RedirectTo = canonicalRedirect + + logoutCookieJson, err := json.Marshal(in.Cookie) + if err != nil { + return fmt.Errorf("marshalling logout cookie: %w", err) + } + + value := string(logoutCookieJson) + return cookie.EncryptAndSet(w, cookie.Logout, value, opts, crypter) +} diff --git a/pkg/openid/client/logout_callback.go b/pkg/openid/client/logout_callback.go index 73a541f..b4e1684 100644 --- a/pkg/openid/client/logout_callback.go +++ b/pkg/openid/client/logout_callback.go @@ -1,34 +1,54 @@ package client import ( + "fmt" "net/http" - mw "github.com/nais/wonderwall/pkg/middleware" + "github.com/nais/wonderwall/pkg/openid" + urlpkg "github.com/nais/wonderwall/pkg/url" ) type LogoutCallback struct { *Client - request *http.Request + cookie *openid.LogoutCookie + validator urlpkg.Validator + request *http.Request } -func NewLogoutCallback(c *Client, r *http.Request) *LogoutCallback { +func NewLogoutCallback(c *Client, r *http.Request, cookie *openid.LogoutCookie, validator urlpkg.Validator) *LogoutCallback { return &LogoutCallback{ - Client: c, - request: r, + Client: c, + cookie: cookie, + validator: validator, + request: r, } } func (in *LogoutCallback) PostLogoutRedirectURI() string { - redirect := in.cfg.Client().PostLogoutRedirectURI() - - if len(redirect) > 0 { - return redirect + if in.cookie != nil && in.stateMismatchError() == nil && in.validator.IsValidRedirect(in.request, in.cookie.RedirectTo) { + return in.cookie.RedirectTo } - ingress, ok := mw.IngressFrom(in.request.Context()) - if !ok { + defaultRedirect := in.cfg.Client().PostLogoutRedirectURI() + if defaultRedirect != "" { + return defaultRedirect + } + + ingress, err := urlpkg.MatchingIngress(in.request) + if err != nil { return "/" } return ingress.String() } + +func (in *LogoutCallback) stateMismatchError() error { + if in.cookie == nil { + return fmt.Errorf("logout cookie is nil") + } + + expectedState := in.cookie.State + actualState := in.request.URL.Query().Get(openid.State) + + return StateMismatchError(expectedState, actualState) +} diff --git a/pkg/openid/client/logout_callback_test.go b/pkg/openid/client/logout_callback_test.go index 9fd4993..c43b91e 100644 --- a/pkg/openid/client/logout_callback_test.go +++ b/pkg/openid/client/logout_callback_test.go @@ -7,36 +7,91 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/client" + "github.com/nais/wonderwall/pkg/url" ) func TestLogoutCallback_PostLogoutRedirectURI(t *testing.T) { - t.Run("happy path", func(t *testing.T) { - cfg := mock.Config() - cfg.OpenID.PostLogoutRedirectURI = "http://some-fancy-logout-page" + const defaultState = "some-state" + const defaultRedirectURI = "http://some-fancy-logout-page" - lc := newLogoutCallback(cfg) + for _, tt := range []struct { + name string + emptyDefaultURI bool + cookie *openid.LogoutCookie + expected string + }{ + { + name: "happy path", + expected: defaultRedirectURI, + }, + { + name: "empty default uri", + emptyDefaultURI: true, + expected: mock.Ingress, + }, + { + name: "state mismatch", + cookie: &openid.LogoutCookie{ + State: "some-other-state", + }, + expected: defaultRedirectURI, + }, + { + name: "happy path, redirect in cookie", + cookie: &openid.LogoutCookie{ + State: defaultState, + RedirectTo: "http://wonderwall/some/path", + }, + expected: "http://wonderwall/some/path", + }, + { + name: "empty redirect in cookie", + cookie: &openid.LogoutCookie{ + State: defaultState, + RedirectTo: "", + }, + expected: defaultRedirectURI, + }, + { + name: "state mismatch, with redirect in cookie", + cookie: &openid.LogoutCookie{ + State: "some-other-state", + RedirectTo: "http://wonderwall/some/path", + }, + expected: defaultRedirectURI, + }, + { + name: "invalid redirect in cookie", + cookie: &openid.LogoutCookie{ + State: defaultState, + RedirectTo: "http://not-wonderwall/some/path", + }, + expected: defaultRedirectURI, + }, + } { + t.Run(tt.name, func(t *testing.T) { + cfg := mock.Config() + cfg.OpenID.PostLogoutRedirectURI = defaultRedirectURI - uri := lc.PostLogoutRedirectURI() - assert.NotEmpty(t, uri) - assert.Equal(t, "http://some-fancy-logout-page", uri) - }) + if tt.emptyDefaultURI { + cfg.OpenID.PostLogoutRedirectURI = "" + } - t.Run("empty preconfigured post-logout redirect uri", func(t *testing.T) { - cfg := mock.Config() - cfg.OpenID.PostLogoutRedirectURI = "" + lc := newLogoutCallback(cfg, defaultState, tt.cookie) - lc := newLogoutCallback(cfg) - - uri := lc.PostLogoutRedirectURI() - assert.NotEmpty(t, uri) - assert.Equal(t, mock.Ingress, uri) - }) + uri := lc.PostLogoutRedirectURI() + assert.NotEmpty(t, uri) + assert.Equal(t, tt.expected, uri) + }) + } } -func newLogoutCallback(cfg *config.Config) *client.LogoutCallback { +func newLogoutCallback(cfg *config.Config, state string, cookie *openid.LogoutCookie) *client.LogoutCallback { openidCfg := mock.NewTestConfiguration(cfg) ingresses := mock.Ingresses(cfg) - req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout/callback", ingresses) - return newTestClientWithConfig(openidCfg).LogoutCallback(req) + validator := url.NewAbsoluteValidator(ingresses.Hosts()) + req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout/callback?state="+state, ingresses) + return newTestClientWithConfig(openidCfg).LogoutCallback(req, cookie, validator) } diff --git a/pkg/openid/client/logout_test.go b/pkg/openid/client/logout_test.go index 57c0360..4579d73 100644 --- a/pkg/openid/client/logout_test.go +++ b/pkg/openid/client/logout_test.go @@ -20,6 +20,34 @@ func TestLogout_SingleLogoutURL(t *testing.T) { t.Run("with id_token", func(t *testing.T) { logout := newLogout(t) idToken := "some-id-token" + state := logout.Cookie.State + + raw := logout.SingleLogoutURL(idToken) + assert.NotEmpty(t, raw) + + logoutUrl, err := url.Parse(raw) + assert.NoError(t, err) + + query := logoutUrl.Query() + assert.Len(t, query, 3) + + assert.Contains(t, query, "id_token_hint") + assert.Equal(t, idToken, query.Get("id_token_hint")) + + assert.Contains(t, query, "post_logout_redirect_uri") + assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) + + assert.Contains(t, query, "state") + assert.Equal(t, state, query.Get("state")) + + logoutUrl.RawQuery = "" + assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) + }) + + t.Run("without id_token", func(t *testing.T) { + logout := newLogout(t) + idToken := "" + state := logout.Cookie.State raw := logout.SingleLogoutURL(idToken) assert.NotEmpty(t, raw) @@ -30,35 +58,15 @@ func TestLogout_SingleLogoutURL(t *testing.T) { query := logoutUrl.Query() assert.Len(t, query, 2) - assert.Contains(t, query, "id_token_hint") - assert.Equal(t, idToken, query.Get("id_token_hint")) - - assert.Contains(t, query, "post_logout_redirect_uri") - assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) - - logoutUrl.RawQuery = "" - assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) - }) - - t.Run("without id_token", func(t *testing.T) { - logout := newLogout(t) - idToken := "" - - raw := logout.SingleLogoutURL(idToken) - assert.NotEmpty(t, raw) - - logoutUrl, err := url.Parse(raw) - assert.NoError(t, err) - - query := logoutUrl.Query() - assert.Len(t, query, 1) - assert.NotContains(t, query, "id_token_hint") assert.Equal(t, idToken, query.Get("id_token_hint")) assert.Contains(t, query, "post_logout_redirect_uri") assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) + assert.Contains(t, query, "state") + assert.Equal(t, state, query.Get("state")) + logoutUrl.RawQuery = "" assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) }) diff --git a/pkg/url/redirect.go b/pkg/url/redirect.go index 5126d2f..9c1bed0 100644 --- a/pkg/url/redirect.go +++ b/pkg/url/redirect.go @@ -15,6 +15,8 @@ type Redirect interface { Canonical(r *http.Request) string // Clean parses and cleans a target URL according to implementation-specific validations. It should always return a fallback URL string, regardless of validation errors. Clean(r *http.Request, target string) string + // GetValidator returns the Validator used to Clean URLs. + GetValidator() Validator } var _ Redirect = &StandaloneRedirect{} @@ -33,7 +35,7 @@ func (h *StandaloneRedirect) Canonical(r *http.Request) string { target := redirectQueryParam(r) redirect, err := url.ParseRequestURI(target) if err != nil { - redirect = fallback(r, target, h.FallbackRedirect(r)) + redirect = fallback(r, target, h.getFallbackRedirect(r)) } // redirect must be a relative URL to avoid cross-domain redirects @@ -44,10 +46,10 @@ func (h *StandaloneRedirect) Canonical(r *http.Request) string { } func (h *StandaloneRedirect) Clean(r *http.Request, target string) string { - return h.clean(r, target, h.FallbackRedirect(r)) + return h.clean(r, target, h.getFallbackRedirect(r)) } -func (h *StandaloneRedirect) FallbackRedirect(r *http.Request) *url.URL { +func (h *StandaloneRedirect) getFallbackRedirect(r *http.Request) *url.URL { return MatchingPath(r) } @@ -145,6 +147,10 @@ func newRelativeCleaner(allowedHosts []string) *cleaner { } } +func (c *cleaner) GetValidator() Validator { + return c.Validator +} + func (c *cleaner) clean(r *http.Request, target string, fallbackTarget *url.URL) string { if c.IsValidRedirect(r, target) { return target diff --git a/pkg/url/url.go b/pkg/url/url.go index f033bbb..2fd32a5 100644 --- a/pkg/url/url.go +++ b/pkg/url/url.go @@ -45,9 +45,11 @@ func LoginRelative(prefix, redirect string) string { func Logout(target *url.URL, redirect string) string { u := target.JoinPath(paths.OAuth2, paths.Logout) - v := u.Query() - v.Set(RedirectQueryParameter, redirect) - u.RawQuery = v.Encode() + if len(redirect) > 0 { + v := u.Query() + v.Set(RedirectQueryParameter, redirect) + u.RawQuery = v.Encode() + } return u.String() } diff --git a/pkg/url/url_test.go b/pkg/url/url_test.go index 21f658b..80be676 100644 --- a/pkg/url/url_test.go +++ b/pkg/url/url_test.go @@ -122,6 +122,12 @@ func TestLogout(t *testing.T) { redirectTarget: "/path?some=param&other=param2", want: "https://sso.wonderwall/oauth2/logout?redirect=%2Fpath%3Fsome%3Dparam%26other%3Dparam2", }, + { + name: "empty redirect target", + targetURL: "https://sso.wonderwall", + redirectTarget: "", + want: "https://sso.wonderwall/oauth2/logout", + }, } { t.Run(test.name, func(t *testing.T) { targetURL, err := url.Parse(test.targetURL)