feat(openid, handler): support runtime override of redirect after single-logout

Fixes #100.
This commit is contained in:
Trong Huu Nguyen
2023-05-04 14:31:59 +02:00
parent b0bb1aa8ea
commit 6151aa3279
14 changed files with 324 additions and 108 deletions

View File

@@ -39,7 +39,6 @@ type Standalone struct {
CookieOptions cookie.Options
Crypter crypto.Crypter
Ingresses *ingress.Ingresses
OpenidConfig openidconfig.Config
Redirect url.Redirect
SessionManager session.Manager
UpstreamProxy *ReverseProxy
@@ -86,7 +85,6 @@ func NewStandalone(
CookieOptions: cookieOpts,
Crypter: crypter,
Ingresses: ingresses,
OpenidConfig: openidConfig,
Redirect: url.NewStandaloneRedirect(ingresses),
SessionManager: sessionManager,
UpstreamProxy: NewReverseProxy(upstream, true),
@@ -255,7 +253,21 @@ func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, globalLogout
cookie.Clear(w, cookie.Session, s.GetCookieOptions(r))
if globalLogout {
logger.Debug("logout: redirecting to identity provider for global/single-logout")
// only set a canonical redirect if it was provided in the request as a query parameter
canonicalRedirect := r.URL.Query().Get(url.RedirectQueryParameter)
if canonicalRedirect != "" {
canonicalRedirect = s.Redirect.Canonical(r)
}
opts := s.CookieOptions.WithExpiresIn(5 * time.Minute)
err = logout.SetCookie(w, opts, s.Crypter, canonicalRedirect)
if err != nil {
s.InternalError(w, r, fmt.Errorf("logout: setting logout cookie: %w", err))
return
}
logger.WithField("redirect_after_logout", canonicalRedirect).
Info("logout: redirecting to identity provider for global/single-logout")
metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated)
http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusFound)
} else {
@@ -266,10 +278,19 @@ func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, globalLogout
}
func (s *Standalone) LogoutCallback(w http.ResponseWriter, r *http.Request) {
redirect := s.Client.LogoutCallback(r).PostLogoutRedirectURI()
logger := mw.LogEntryFrom(r)
cookie.Clear(w, cookie.Logout, s.CookieOptions)
logoutCookie, err := openid.GetLogoutCookie(r, s.Crypter)
if err != nil {
logger.Debugf("logout/callback: getting cookie: %+v; ignoring...", err)
}
logoutCallback := s.Client.LogoutCallback(r, logoutCookie, s.Redirect.GetValidator())
redirect := logoutCallback.PostLogoutRedirectURI()
cookie.Clear(w, cookie.Retry, s.GetCookieOptions(r))
mw.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect)
logger.Infof("logout/callback: redirecting to %q", redirect)
http.Redirect(w, r, redirect, http.StatusFound)
}

View File

@@ -141,8 +141,21 @@ func (s *SSOProxy) LoginCallback(w http.ResponseWriter, r *http.Request) {
}
func (s *SSOProxy) Logout(w http.ResponseWriter, r *http.Request) {
target := s.GetSSOServerURL().JoinPath(paths.OAuth2, paths.Logout)
http.Redirect(w, r, target.String(), http.StatusFound)
target := s.GetSSOServerURL()
// only set a canonical redirect if it was provided in the request as a query parameter
canonicalRedirect := r.URL.Query().Get(url.RedirectQueryParameter)
if canonicalRedirect != "" {
canonicalRedirect = s.Redirect.Canonical(r)
}
ssoServerLogoutURL := url.Logout(target, canonicalRedirect)
logentry.LogEntryFrom(r).WithFields(log.Fields{
"redirect_to": ssoServerLogoutURL,
"redirect_after_logout": canonicalRedirect,
}).Info("logout: redirecting to sso server")
http.Redirect(w, r, ssoServerLogoutURL, http.StatusFound)
}
func (s *SSOProxy) LogoutCallback(w http.ResponseWriter, r *http.Request) {

View File

@@ -90,24 +90,7 @@ func TestLogout(t *testing.T) {
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
resp := selfInitiatedLogout(t, rpClient, idp)
// Get endsession endpoint after local logout
endsessionURL := resp.Location
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
assert.NoError(t, err)
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback")
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
assert.NoError(t, err)
endsessionParams := endsessionURL.Query()
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
assert.Equal(t, "/endsession", endsessionURL.Path)
assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"])
assert.NotEmpty(t, endsessionParams["id_token_hint"])
selfInitiatedLogout(t, rpClient, idp)
}
func TestLogoutLocal(t *testing.T) {
@@ -131,6 +114,18 @@ func TestLogoutCallback(t *testing.T) {
logout(t, rpClient, idp)
}
func TestLogoutCallback_WithRedirect(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
defer idp.Close()
redirect := idp.RelyingPartyServer.URL + "/api/me"
rpClient := idp.RelyingPartyClient()
login(t, rpClient, idp)
logout(t, rpClient, idp, redirect)
}
func TestFrontChannelLogout(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
@@ -485,11 +480,17 @@ func login(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) *htt
return callback(t, rpClient, resp)
}
func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider, redirectAfterLogout ...string) response {
// Request self-initiated logout
logoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout")
assert.NoError(t, err)
if len(redirectAfterLogout) > 0 {
v := url.Values{}
v.Set(urlpkg.RedirectQueryParameter, redirectAfterLogout[0])
logoutURL.RawQuery = v.Encode()
}
resp := get(t, rpClient, logoutURL.String())
assert.Equal(t, http.StatusFound, resp.StatusCode)
@@ -498,21 +499,44 @@ func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.Identity
assert.Nil(t, sessionCookie)
// Get endsession endpoint after local logout
endsessionURL := resp.Location
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
assert.NoError(t, err)
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout")
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
assert.NoError(t, err)
endsessionParams := endsessionURL.Query()
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
assert.Equal(t, "/endsession", endsessionURL.Path)
assert.Equal(t, expectedLogoutCallbackURL, endsessionParams.Get("post_logout_redirect_uri"))
assert.NotEmpty(t, endsessionParams.Get("id_token_hint"))
assert.NotEmpty(t, endsessionParams.Get("state"))
return resp
}
func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider, redirectAfterLogout ...string) {
// Get endsession endpoint after local logout
resp := selfInitiatedLogout(t, rpClient, idp)
resp := selfInitiatedLogout(t, rpClient, idp, redirectAfterLogout...)
expectedState := resp.Location.Query().Get("state")
// Follow redirect to endsession endpoint at identity provider
resp = get(t, rpClient, resp.Location.String())
assert.Equal(t, http.StatusFound, resp.StatusCode)
// Get post-logout redirect URI after successful logout at identity provider
logoutCallbackURI := resp.Location
req := idp.GetRequest(resp.Location.String())
// Assert state for callback equals state sent in initial logout request
actualState := logoutCallbackURI.Query().Get("state")
assert.NotEmpty(t, actualState)
assert.Equal(t, expectedState, actualState)
// Assert post-logout redirect URI after successful logout at identity provider
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout")
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
assert.NoError(t, err)
@@ -521,10 +545,15 @@ func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
// Follow redirect back to logout callback
resp = get(t, rpClient, logoutCallbackURI.String())
assert.Equal(t, http.StatusFound, resp.StatusCode)
// Get post-logout redirect URI after redirect back to logout callback
assert.Equal(t, "https://google.com", resp.Location.String())
assert.Equal(t, http.StatusFound, resp.StatusCode)
expectedRedirect := "https://google.com"
if len(redirectAfterLogout) > 0 {
expectedRedirect = redirectAfterLogout[0]
}
assert.Equal(t, expectedRedirect, resp.Location.String())
cookies := rpClient.Jar.Cookies(logoutCallbackURI)
sessionCookie := getCookieFromJar(cookie.Session, cookies)

View File

@@ -558,10 +558,17 @@ func (ip *IdentityProviderHandler) validateClientAuthentication(w http.ResponseW
func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
query := r.URL.Query()
postLogoutRedirectURI := query.Get("post_logout_redirect_uri")
state := query.Get("state")
if postLogoutRedirectURI == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing required parameters"))
w.Write([]byte("missing required 'post_logout_redirect_uri' parameter"))
return
}
if state == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing required 'state' parameter"))
return
}
@@ -572,6 +579,10 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req
return
}
v := url.Values{}
v.Set("state", state)
u.RawQuery = v.Encode()
http.Redirect(w, r, u.String(), http.StatusFound)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/nais/wonderwall/pkg/openid"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
var (
@@ -91,8 +92,8 @@ func (c *Client) Logout(r *http.Request) (*Logout, error) {
return logout, nil
}
func (c *Client) LogoutCallback(r *http.Request) *LogoutCallback {
return NewLogoutCallback(c, r)
func (c *Client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie, validator urlpkg.Validator) *LogoutCallback {
return NewLogoutCallback(c, r, cookie, validator)
}
func (c *Client) LogoutFrontchannel(r *http.Request) *LogoutFrontchannel {
@@ -182,3 +183,15 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
return &tokenResponse, nil
}
func StateMismatchError(expectedState, actualState string) error {
if len(actualState) <= 0 {
return fmt.Errorf("missing state parameter in request (possible csrf)")
}
if expectedState != actualState {
return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState)
}
return nil
}

View File

@@ -45,6 +45,22 @@ func TestMakeAssertion(t *testing.T) {
assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry)))
}
func TestStateMismatchError(t *testing.T) {
for _, tt := range []struct {
name, expected, actual string
assertion assert.ErrorAssertionFunc
}{
{"missing actual state", "expected", "", assert.Error},
{"state mismatch", "match", "not-match", assert.Error},
{"state match", "match", "match", assert.NoError},
} {
t.Run(tt.name, func(t *testing.T) {
err := client.StateMismatchError(tt.expected, tt.actual)
tt.assertion(t, err)
})
}
}
func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client {
jwksProvider := mock.NewTestJwksProvider()
return client.NewClient(config, jwksProvider)

View File

@@ -15,7 +15,6 @@ import (
type LoginCallback struct {
*Client
cookie *openid.LoginCookie
request *http.Request
requestParams url.Values
}
@@ -37,7 +36,6 @@ func NewLoginCallback(c *Client, r *http.Request, cookie *openid.LoginCookie) (*
return &LoginCallback{
Client: c,
cookie: cookie,
request: r,
requestParams: r.URL.Query(),
}, nil
}
@@ -56,15 +54,7 @@ func (in *LoginCallback) StateMismatchError() error {
expectedState := in.cookie.State
actualState := in.requestParams.Get(openid.State)
if len(actualState) <= 0 {
return fmt.Errorf("missing state parameter in request (possible csrf)")
}
if expectedState != actualState {
return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState)
}
return nil
return StateMismatchError(expectedState, actualState)
}
func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) {

View File

@@ -1,16 +1,20 @@
package client
import (
"encoding/json"
"fmt"
"net/http"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/strings"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
type Logout struct {
*Client
request *http.Request
Cookie *openid.LogoutCookie
logoutCallbackURL string
}
@@ -20,10 +24,19 @@ func NewLogout(c *Client, r *http.Request) (*Logout, error) {
return nil, fmt.Errorf("generating logout callback url: %w", err)
}
state, err := strings.GenerateBase64(32)
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
logoutCookie := &openid.LogoutCookie{
State: state,
}
return &Logout{
Client: c,
Cookie: logoutCookie,
logoutCallbackURL: logoutCallbackURL,
request: r,
}, nil
}
@@ -31,6 +44,7 @@ func (in *Logout) SingleLogoutURL(idToken string) string {
endSessionEndpoint := in.cfg.Provider().EndSessionEndpointURL()
v := endSessionEndpoint.Query()
v.Add(openid.PostLogoutRedirectURI, in.logoutCallbackURL)
v.Add(openid.State, in.Cookie.State)
if len(idToken) > 0 {
v.Add(openid.IDTokenHint, idToken)
@@ -39,3 +53,15 @@ func (in *Logout) SingleLogoutURL(idToken string) string {
endSessionEndpoint.RawQuery = v.Encode()
return endSessionEndpoint.String()
}
func (in *Logout) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error {
in.Cookie.RedirectTo = canonicalRedirect
logoutCookieJson, err := json.Marshal(in.Cookie)
if err != nil {
return fmt.Errorf("marshalling logout cookie: %w", err)
}
value := string(logoutCookieJson)
return cookie.EncryptAndSet(w, cookie.Logout, value, opts, crypter)
}

View File

@@ -1,34 +1,54 @@
package client
import (
"fmt"
"net/http"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
urlpkg "github.com/nais/wonderwall/pkg/url"
)
type LogoutCallback struct {
*Client
request *http.Request
cookie *openid.LogoutCookie
validator urlpkg.Validator
request *http.Request
}
func NewLogoutCallback(c *Client, r *http.Request) *LogoutCallback {
func NewLogoutCallback(c *Client, r *http.Request, cookie *openid.LogoutCookie, validator urlpkg.Validator) *LogoutCallback {
return &LogoutCallback{
Client: c,
request: r,
Client: c,
cookie: cookie,
validator: validator,
request: r,
}
}
func (in *LogoutCallback) PostLogoutRedirectURI() string {
redirect := in.cfg.Client().PostLogoutRedirectURI()
if len(redirect) > 0 {
return redirect
if in.cookie != nil && in.stateMismatchError() == nil && in.validator.IsValidRedirect(in.request, in.cookie.RedirectTo) {
return in.cookie.RedirectTo
}
ingress, ok := mw.IngressFrom(in.request.Context())
if !ok {
defaultRedirect := in.cfg.Client().PostLogoutRedirectURI()
if defaultRedirect != "" {
return defaultRedirect
}
ingress, err := urlpkg.MatchingIngress(in.request)
if err != nil {
return "/"
}
return ingress.String()
}
func (in *LogoutCallback) stateMismatchError() error {
if in.cookie == nil {
return fmt.Errorf("logout cookie is nil")
}
expectedState := in.cookie.State
actualState := in.request.URL.Query().Get(openid.State)
return StateMismatchError(expectedState, actualState)
}

View File

@@ -7,36 +7,91 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/client"
"github.com/nais/wonderwall/pkg/url"
)
func TestLogoutCallback_PostLogoutRedirectURI(t *testing.T) {
t.Run("happy path", func(t *testing.T) {
cfg := mock.Config()
cfg.OpenID.PostLogoutRedirectURI = "http://some-fancy-logout-page"
const defaultState = "some-state"
const defaultRedirectURI = "http://some-fancy-logout-page"
lc := newLogoutCallback(cfg)
for _, tt := range []struct {
name string
emptyDefaultURI bool
cookie *openid.LogoutCookie
expected string
}{
{
name: "happy path",
expected: defaultRedirectURI,
},
{
name: "empty default uri",
emptyDefaultURI: true,
expected: mock.Ingress,
},
{
name: "state mismatch",
cookie: &openid.LogoutCookie{
State: "some-other-state",
},
expected: defaultRedirectURI,
},
{
name: "happy path, redirect in cookie",
cookie: &openid.LogoutCookie{
State: defaultState,
RedirectTo: "http://wonderwall/some/path",
},
expected: "http://wonderwall/some/path",
},
{
name: "empty redirect in cookie",
cookie: &openid.LogoutCookie{
State: defaultState,
RedirectTo: "",
},
expected: defaultRedirectURI,
},
{
name: "state mismatch, with redirect in cookie",
cookie: &openid.LogoutCookie{
State: "some-other-state",
RedirectTo: "http://wonderwall/some/path",
},
expected: defaultRedirectURI,
},
{
name: "invalid redirect in cookie",
cookie: &openid.LogoutCookie{
State: defaultState,
RedirectTo: "http://not-wonderwall/some/path",
},
expected: defaultRedirectURI,
},
} {
t.Run(tt.name, func(t *testing.T) {
cfg := mock.Config()
cfg.OpenID.PostLogoutRedirectURI = defaultRedirectURI
uri := lc.PostLogoutRedirectURI()
assert.NotEmpty(t, uri)
assert.Equal(t, "http://some-fancy-logout-page", uri)
})
if tt.emptyDefaultURI {
cfg.OpenID.PostLogoutRedirectURI = ""
}
t.Run("empty preconfigured post-logout redirect uri", func(t *testing.T) {
cfg := mock.Config()
cfg.OpenID.PostLogoutRedirectURI = ""
lc := newLogoutCallback(cfg, defaultState, tt.cookie)
lc := newLogoutCallback(cfg)
uri := lc.PostLogoutRedirectURI()
assert.NotEmpty(t, uri)
assert.Equal(t, mock.Ingress, uri)
})
uri := lc.PostLogoutRedirectURI()
assert.NotEmpty(t, uri)
assert.Equal(t, tt.expected, uri)
})
}
}
func newLogoutCallback(cfg *config.Config) *client.LogoutCallback {
func newLogoutCallback(cfg *config.Config, state string, cookie *openid.LogoutCookie) *client.LogoutCallback {
openidCfg := mock.NewTestConfiguration(cfg)
ingresses := mock.Ingresses(cfg)
req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout/callback", ingresses)
return newTestClientWithConfig(openidCfg).LogoutCallback(req)
validator := url.NewAbsoluteValidator(ingresses.Hosts())
req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout/callback?state="+state, ingresses)
return newTestClientWithConfig(openidCfg).LogoutCallback(req, cookie, validator)
}

View File

@@ -20,6 +20,34 @@ func TestLogout_SingleLogoutURL(t *testing.T) {
t.Run("with id_token", func(t *testing.T) {
logout := newLogout(t)
idToken := "some-id-token"
state := logout.Cookie.State
raw := logout.SingleLogoutURL(idToken)
assert.NotEmpty(t, raw)
logoutUrl, err := url.Parse(raw)
assert.NoError(t, err)
query := logoutUrl.Query()
assert.Len(t, query, 3)
assert.Contains(t, query, "id_token_hint")
assert.Equal(t, idToken, query.Get("id_token_hint"))
assert.Contains(t, query, "post_logout_redirect_uri")
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
assert.Contains(t, query, "state")
assert.Equal(t, state, query.Get("state"))
logoutUrl.RawQuery = ""
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
})
t.Run("without id_token", func(t *testing.T) {
logout := newLogout(t)
idToken := ""
state := logout.Cookie.State
raw := logout.SingleLogoutURL(idToken)
assert.NotEmpty(t, raw)
@@ -30,35 +58,15 @@ func TestLogout_SingleLogoutURL(t *testing.T) {
query := logoutUrl.Query()
assert.Len(t, query, 2)
assert.Contains(t, query, "id_token_hint")
assert.Equal(t, idToken, query.Get("id_token_hint"))
assert.Contains(t, query, "post_logout_redirect_uri")
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
logoutUrl.RawQuery = ""
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
})
t.Run("without id_token", func(t *testing.T) {
logout := newLogout(t)
idToken := ""
raw := logout.SingleLogoutURL(idToken)
assert.NotEmpty(t, raw)
logoutUrl, err := url.Parse(raw)
assert.NoError(t, err)
query := logoutUrl.Query()
assert.Len(t, query, 1)
assert.NotContains(t, query, "id_token_hint")
assert.Equal(t, idToken, query.Get("id_token_hint"))
assert.Contains(t, query, "post_logout_redirect_uri")
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
assert.Contains(t, query, "state")
assert.Equal(t, state, query.Get("state"))
logoutUrl.RawQuery = ""
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
})

View File

@@ -15,6 +15,8 @@ type Redirect interface {
Canonical(r *http.Request) string
// Clean parses and cleans a target URL according to implementation-specific validations. It should always return a fallback URL string, regardless of validation errors.
Clean(r *http.Request, target string) string
// GetValidator returns the Validator used to Clean URLs.
GetValidator() Validator
}
var _ Redirect = &StandaloneRedirect{}
@@ -33,7 +35,7 @@ func (h *StandaloneRedirect) Canonical(r *http.Request) string {
target := redirectQueryParam(r)
redirect, err := url.ParseRequestURI(target)
if err != nil {
redirect = fallback(r, target, h.FallbackRedirect(r))
redirect = fallback(r, target, h.getFallbackRedirect(r))
}
// redirect must be a relative URL to avoid cross-domain redirects
@@ -44,10 +46,10 @@ func (h *StandaloneRedirect) Canonical(r *http.Request) string {
}
func (h *StandaloneRedirect) Clean(r *http.Request, target string) string {
return h.clean(r, target, h.FallbackRedirect(r))
return h.clean(r, target, h.getFallbackRedirect(r))
}
func (h *StandaloneRedirect) FallbackRedirect(r *http.Request) *url.URL {
func (h *StandaloneRedirect) getFallbackRedirect(r *http.Request) *url.URL {
return MatchingPath(r)
}
@@ -145,6 +147,10 @@ func newRelativeCleaner(allowedHosts []string) *cleaner {
}
}
func (c *cleaner) GetValidator() Validator {
return c.Validator
}
func (c *cleaner) clean(r *http.Request, target string, fallbackTarget *url.URL) string {
if c.IsValidRedirect(r, target) {
return target

View File

@@ -45,9 +45,11 @@ func LoginRelative(prefix, redirect string) string {
func Logout(target *url.URL, redirect string) string {
u := target.JoinPath(paths.OAuth2, paths.Logout)
v := u.Query()
v.Set(RedirectQueryParameter, redirect)
u.RawQuery = v.Encode()
if len(redirect) > 0 {
v := u.Query()
v.Set(RedirectQueryParameter, redirect)
u.RawQuery = v.Encode()
}
return u.String()
}

View File

@@ -122,6 +122,12 @@ func TestLogout(t *testing.T) {
redirectTarget: "/path?some=param&other=param2",
want: "https://sso.wonderwall/oauth2/logout?redirect=%2Fpath%3Fsome%3Dparam%26other%3Dparam2",
},
{
name: "empty redirect target",
targetURL: "https://sso.wonderwall",
redirectTarget: "",
want: "https://sso.wonderwall/oauth2/logout",
},
} {
t.Run(test.name, func(t *testing.T) {
targetURL, err := url.Parse(test.targetURL)