mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-06 08:27:10 +00:00
feat(openid, handler): support runtime override of redirect after single-logout
Fixes #100.
This commit is contained in:
@@ -39,7 +39,6 @@ type Standalone struct {
|
||||
CookieOptions cookie.Options
|
||||
Crypter crypto.Crypter
|
||||
Ingresses *ingress.Ingresses
|
||||
OpenidConfig openidconfig.Config
|
||||
Redirect url.Redirect
|
||||
SessionManager session.Manager
|
||||
UpstreamProxy *ReverseProxy
|
||||
@@ -86,7 +85,6 @@ func NewStandalone(
|
||||
CookieOptions: cookieOpts,
|
||||
Crypter: crypter,
|
||||
Ingresses: ingresses,
|
||||
OpenidConfig: openidConfig,
|
||||
Redirect: url.NewStandaloneRedirect(ingresses),
|
||||
SessionManager: sessionManager,
|
||||
UpstreamProxy: NewReverseProxy(upstream, true),
|
||||
@@ -255,7 +253,21 @@ func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, globalLogout
|
||||
cookie.Clear(w, cookie.Session, s.GetCookieOptions(r))
|
||||
|
||||
if globalLogout {
|
||||
logger.Debug("logout: redirecting to identity provider for global/single-logout")
|
||||
// only set a canonical redirect if it was provided in the request as a query parameter
|
||||
canonicalRedirect := r.URL.Query().Get(url.RedirectQueryParameter)
|
||||
if canonicalRedirect != "" {
|
||||
canonicalRedirect = s.Redirect.Canonical(r)
|
||||
}
|
||||
|
||||
opts := s.CookieOptions.WithExpiresIn(5 * time.Minute)
|
||||
err = logout.SetCookie(w, opts, s.Crypter, canonicalRedirect)
|
||||
if err != nil {
|
||||
s.InternalError(w, r, fmt.Errorf("logout: setting logout cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
logger.WithField("redirect_after_logout", canonicalRedirect).
|
||||
Info("logout: redirecting to identity provider for global/single-logout")
|
||||
metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated)
|
||||
http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusFound)
|
||||
} else {
|
||||
@@ -266,10 +278,19 @@ func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, globalLogout
|
||||
}
|
||||
|
||||
func (s *Standalone) LogoutCallback(w http.ResponseWriter, r *http.Request) {
|
||||
redirect := s.Client.LogoutCallback(r).PostLogoutRedirectURI()
|
||||
logger := mw.LogEntryFrom(r)
|
||||
cookie.Clear(w, cookie.Logout, s.CookieOptions)
|
||||
|
||||
logoutCookie, err := openid.GetLogoutCookie(r, s.Crypter)
|
||||
if err != nil {
|
||||
logger.Debugf("logout/callback: getting cookie: %+v; ignoring...", err)
|
||||
}
|
||||
|
||||
logoutCallback := s.Client.LogoutCallback(r, logoutCookie, s.Redirect.GetValidator())
|
||||
redirect := logoutCallback.PostLogoutRedirectURI()
|
||||
|
||||
cookie.Clear(w, cookie.Retry, s.GetCookieOptions(r))
|
||||
mw.LogEntryFrom(r).Debugf("logout/callback: redirecting to %s", redirect)
|
||||
logger.Infof("logout/callback: redirecting to %q", redirect)
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
}
|
||||
|
||||
|
||||
@@ -141,8 +141,21 @@ func (s *SSOProxy) LoginCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *SSOProxy) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
target := s.GetSSOServerURL().JoinPath(paths.OAuth2, paths.Logout)
|
||||
http.Redirect(w, r, target.String(), http.StatusFound)
|
||||
target := s.GetSSOServerURL()
|
||||
|
||||
// only set a canonical redirect if it was provided in the request as a query parameter
|
||||
canonicalRedirect := r.URL.Query().Get(url.RedirectQueryParameter)
|
||||
if canonicalRedirect != "" {
|
||||
canonicalRedirect = s.Redirect.Canonical(r)
|
||||
}
|
||||
ssoServerLogoutURL := url.Logout(target, canonicalRedirect)
|
||||
|
||||
logentry.LogEntryFrom(r).WithFields(log.Fields{
|
||||
"redirect_to": ssoServerLogoutURL,
|
||||
"redirect_after_logout": canonicalRedirect,
|
||||
}).Info("logout: redirecting to sso server")
|
||||
|
||||
http.Redirect(w, r, ssoServerLogoutURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func (s *SSOProxy) LogoutCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -90,24 +90,7 @@ func TestLogout(t *testing.T) {
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
|
||||
resp := selfInitiatedLogout(t, rpClient, idp)
|
||||
|
||||
// Get endsession endpoint after local logout
|
||||
endsessionURL := resp.Location
|
||||
|
||||
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout/callback")
|
||||
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
endsessionParams := endsessionURL.Query()
|
||||
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
|
||||
assert.Equal(t, "/endsession", endsessionURL.Path)
|
||||
assert.Equal(t, []string{expectedLogoutCallbackURL}, endsessionParams["post_logout_redirect_uri"])
|
||||
assert.NotEmpty(t, endsessionParams["id_token_hint"])
|
||||
selfInitiatedLogout(t, rpClient, idp)
|
||||
}
|
||||
|
||||
func TestLogoutLocal(t *testing.T) {
|
||||
@@ -131,6 +114,18 @@ func TestLogoutCallback(t *testing.T) {
|
||||
logout(t, rpClient, idp)
|
||||
}
|
||||
|
||||
func TestLogoutCallback_WithRedirect(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
redirect := idp.RelyingPartyServer.URL + "/api/me"
|
||||
|
||||
rpClient := idp.RelyingPartyClient()
|
||||
login(t, rpClient, idp)
|
||||
logout(t, rpClient, idp, redirect)
|
||||
}
|
||||
|
||||
func TestFrontChannelLogout(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
@@ -485,11 +480,17 @@ func login(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) *htt
|
||||
return callback(t, rpClient, resp)
|
||||
}
|
||||
|
||||
func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) response {
|
||||
func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider, redirectAfterLogout ...string) response {
|
||||
// Request self-initiated logout
|
||||
logoutURL, err := url.Parse(idp.RelyingPartyServer.URL + "/oauth2/logout")
|
||||
assert.NoError(t, err)
|
||||
|
||||
if len(redirectAfterLogout) > 0 {
|
||||
v := url.Values{}
|
||||
v.Set(urlpkg.RedirectQueryParameter, redirectAfterLogout[0])
|
||||
logoutURL.RawQuery = v.Encode()
|
||||
}
|
||||
|
||||
resp := get(t, rpClient, logoutURL.String())
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
|
||||
@@ -498,21 +499,44 @@ func selfInitiatedLogout(t *testing.T, rpClient *http.Client, idp *mock.Identity
|
||||
|
||||
assert.Nil(t, sessionCookie)
|
||||
|
||||
// Get endsession endpoint after local logout
|
||||
endsessionURL := resp.Location
|
||||
|
||||
idpserverURL, err := url.Parse(idp.ProviderServer.URL)
|
||||
assert.NoError(t, err)
|
||||
|
||||
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout")
|
||||
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
endsessionParams := endsessionURL.Query()
|
||||
assert.Equal(t, idpserverURL.Host, endsessionURL.Host)
|
||||
assert.Equal(t, "/endsession", endsessionURL.Path)
|
||||
assert.Equal(t, expectedLogoutCallbackURL, endsessionParams.Get("post_logout_redirect_uri"))
|
||||
assert.NotEmpty(t, endsessionParams.Get("id_token_hint"))
|
||||
assert.NotEmpty(t, endsessionParams.Get("state"))
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
|
||||
func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider, redirectAfterLogout ...string) {
|
||||
// Get endsession endpoint after local logout
|
||||
resp := selfInitiatedLogout(t, rpClient, idp)
|
||||
resp := selfInitiatedLogout(t, rpClient, idp, redirectAfterLogout...)
|
||||
expectedState := resp.Location.Query().Get("state")
|
||||
|
||||
// Follow redirect to endsession endpoint at identity provider
|
||||
resp = get(t, rpClient, resp.Location.String())
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
|
||||
// Get post-logout redirect URI after successful logout at identity provider
|
||||
logoutCallbackURI := resp.Location
|
||||
|
||||
req := idp.GetRequest(resp.Location.String())
|
||||
// Assert state for callback equals state sent in initial logout request
|
||||
actualState := logoutCallbackURI.Query().Get("state")
|
||||
assert.NotEmpty(t, actualState)
|
||||
assert.Equal(t, expectedState, actualState)
|
||||
|
||||
// Assert post-logout redirect URI after successful logout at identity provider
|
||||
req := idp.GetRequest(idp.RelyingPartyServer.URL + "/oauth2/logout")
|
||||
expectedLogoutCallbackURL, err := urlpkg.LogoutCallback(req)
|
||||
assert.NoError(t, err)
|
||||
|
||||
@@ -521,10 +545,15 @@ func logout(t *testing.T, rpClient *http.Client, idp *mock.IdentityProvider) {
|
||||
|
||||
// Follow redirect back to logout callback
|
||||
resp = get(t, rpClient, logoutCallbackURI.String())
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
|
||||
// Get post-logout redirect URI after redirect back to logout callback
|
||||
assert.Equal(t, "https://google.com", resp.Location.String())
|
||||
assert.Equal(t, http.StatusFound, resp.StatusCode)
|
||||
|
||||
expectedRedirect := "https://google.com"
|
||||
if len(redirectAfterLogout) > 0 {
|
||||
expectedRedirect = redirectAfterLogout[0]
|
||||
}
|
||||
assert.Equal(t, expectedRedirect, resp.Location.String())
|
||||
|
||||
cookies := rpClient.Jar.Cookies(logoutCallbackURI)
|
||||
sessionCookie := getCookieFromJar(cookie.Session, cookies)
|
||||
|
||||
@@ -558,10 +558,17 @@ func (ip *IdentityProviderHandler) validateClientAuthentication(w http.ResponseW
|
||||
func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) {
|
||||
query := r.URL.Query()
|
||||
postLogoutRedirectURI := query.Get("post_logout_redirect_uri")
|
||||
state := query.Get("state")
|
||||
|
||||
if postLogoutRedirectURI == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing required parameters"))
|
||||
w.Write([]byte("missing required 'post_logout_redirect_uri' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte("missing required 'state' parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -572,6 +579,10 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req
|
||||
return
|
||||
}
|
||||
|
||||
v := url.Values{}
|
||||
v.Set("state", state)
|
||||
u.RawQuery = v.Encode()
|
||||
|
||||
http.Redirect(w, r, u.String(), http.StatusFound)
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
urlpkg "github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -91,8 +92,8 @@ func (c *Client) Logout(r *http.Request) (*Logout, error) {
|
||||
return logout, nil
|
||||
}
|
||||
|
||||
func (c *Client) LogoutCallback(r *http.Request) *LogoutCallback {
|
||||
return NewLogoutCallback(c, r)
|
||||
func (c *Client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie, validator urlpkg.Validator) *LogoutCallback {
|
||||
return NewLogoutCallback(c, r, cookie, validator)
|
||||
}
|
||||
|
||||
func (c *Client) LogoutFrontchannel(r *http.Request) *LogoutFrontchannel {
|
||||
@@ -182,3 +183,15 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
|
||||
|
||||
return &tokenResponse, nil
|
||||
}
|
||||
|
||||
func StateMismatchError(expectedState, actualState string) error {
|
||||
if len(actualState) <= 0 {
|
||||
return fmt.Errorf("missing state parameter in request (possible csrf)")
|
||||
}
|
||||
|
||||
if expectedState != actualState {
|
||||
return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -45,6 +45,22 @@ func TestMakeAssertion(t *testing.T) {
|
||||
assert.True(t, assertion.Expiration().Before(time.Now().Add(expiry)))
|
||||
}
|
||||
|
||||
func TestStateMismatchError(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
name, expected, actual string
|
||||
assertion assert.ErrorAssertionFunc
|
||||
}{
|
||||
{"missing actual state", "expected", "", assert.Error},
|
||||
{"state mismatch", "match", "not-match", assert.Error},
|
||||
{"state match", "match", "match", assert.NoError},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := client.StateMismatchError(tt.expected, tt.actual)
|
||||
tt.assertion(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client {
|
||||
jwksProvider := mock.NewTestJwksProvider()
|
||||
return client.NewClient(config, jwksProvider)
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
type LoginCallback struct {
|
||||
*Client
|
||||
cookie *openid.LoginCookie
|
||||
request *http.Request
|
||||
requestParams url.Values
|
||||
}
|
||||
|
||||
@@ -37,7 +36,6 @@ func NewLoginCallback(c *Client, r *http.Request, cookie *openid.LoginCookie) (*
|
||||
return &LoginCallback{
|
||||
Client: c,
|
||||
cookie: cookie,
|
||||
request: r,
|
||||
requestParams: r.URL.Query(),
|
||||
}, nil
|
||||
}
|
||||
@@ -56,15 +54,7 @@ func (in *LoginCallback) StateMismatchError() error {
|
||||
expectedState := in.cookie.State
|
||||
actualState := in.requestParams.Get(openid.State)
|
||||
|
||||
if len(actualState) <= 0 {
|
||||
return fmt.Errorf("missing state parameter in request (possible csrf)")
|
||||
}
|
||||
|
||||
if expectedState != actualState {
|
||||
return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState)
|
||||
}
|
||||
|
||||
return nil
|
||||
return StateMismatchError(expectedState, actualState)
|
||||
}
|
||||
|
||||
func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) {
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/strings"
|
||||
urlpkg "github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
type Logout struct {
|
||||
*Client
|
||||
request *http.Request
|
||||
Cookie *openid.LogoutCookie
|
||||
logoutCallbackURL string
|
||||
}
|
||||
|
||||
@@ -20,10 +24,19 @@ func NewLogout(c *Client, r *http.Request) (*Logout, error) {
|
||||
return nil, fmt.Errorf("generating logout callback url: %w", err)
|
||||
}
|
||||
|
||||
state, err := strings.GenerateBase64(32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generating state: %w", err)
|
||||
}
|
||||
|
||||
logoutCookie := &openid.LogoutCookie{
|
||||
State: state,
|
||||
}
|
||||
|
||||
return &Logout{
|
||||
Client: c,
|
||||
Cookie: logoutCookie,
|
||||
logoutCallbackURL: logoutCallbackURL,
|
||||
request: r,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -31,6 +44,7 @@ func (in *Logout) SingleLogoutURL(idToken string) string {
|
||||
endSessionEndpoint := in.cfg.Provider().EndSessionEndpointURL()
|
||||
v := endSessionEndpoint.Query()
|
||||
v.Add(openid.PostLogoutRedirectURI, in.logoutCallbackURL)
|
||||
v.Add(openid.State, in.Cookie.State)
|
||||
|
||||
if len(idToken) > 0 {
|
||||
v.Add(openid.IDTokenHint, idToken)
|
||||
@@ -39,3 +53,15 @@ func (in *Logout) SingleLogoutURL(idToken string) string {
|
||||
endSessionEndpoint.RawQuery = v.Encode()
|
||||
return endSessionEndpoint.String()
|
||||
}
|
||||
|
||||
func (in *Logout) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter crypto.Crypter, canonicalRedirect string) error {
|
||||
in.Cookie.RedirectTo = canonicalRedirect
|
||||
|
||||
logoutCookieJson, err := json.Marshal(in.Cookie)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshalling logout cookie: %w", err)
|
||||
}
|
||||
|
||||
value := string(logoutCookieJson)
|
||||
return cookie.EncryptAndSet(w, cookie.Logout, value, opts, crypter)
|
||||
}
|
||||
|
||||
@@ -1,34 +1,54 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
urlpkg "github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
type LogoutCallback struct {
|
||||
*Client
|
||||
request *http.Request
|
||||
cookie *openid.LogoutCookie
|
||||
validator urlpkg.Validator
|
||||
request *http.Request
|
||||
}
|
||||
|
||||
func NewLogoutCallback(c *Client, r *http.Request) *LogoutCallback {
|
||||
func NewLogoutCallback(c *Client, r *http.Request, cookie *openid.LogoutCookie, validator urlpkg.Validator) *LogoutCallback {
|
||||
return &LogoutCallback{
|
||||
Client: c,
|
||||
request: r,
|
||||
Client: c,
|
||||
cookie: cookie,
|
||||
validator: validator,
|
||||
request: r,
|
||||
}
|
||||
}
|
||||
|
||||
func (in *LogoutCallback) PostLogoutRedirectURI() string {
|
||||
redirect := in.cfg.Client().PostLogoutRedirectURI()
|
||||
|
||||
if len(redirect) > 0 {
|
||||
return redirect
|
||||
if in.cookie != nil && in.stateMismatchError() == nil && in.validator.IsValidRedirect(in.request, in.cookie.RedirectTo) {
|
||||
return in.cookie.RedirectTo
|
||||
}
|
||||
|
||||
ingress, ok := mw.IngressFrom(in.request.Context())
|
||||
if !ok {
|
||||
defaultRedirect := in.cfg.Client().PostLogoutRedirectURI()
|
||||
if defaultRedirect != "" {
|
||||
return defaultRedirect
|
||||
}
|
||||
|
||||
ingress, err := urlpkg.MatchingIngress(in.request)
|
||||
if err != nil {
|
||||
return "/"
|
||||
}
|
||||
|
||||
return ingress.String()
|
||||
}
|
||||
|
||||
func (in *LogoutCallback) stateMismatchError() error {
|
||||
if in.cookie == nil {
|
||||
return fmt.Errorf("logout cookie is nil")
|
||||
}
|
||||
|
||||
expectedState := in.cookie.State
|
||||
actualState := in.request.URL.Query().Get(openid.State)
|
||||
|
||||
return StateMismatchError(expectedState, actualState)
|
||||
}
|
||||
|
||||
@@ -7,36 +7,91 @@ import (
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/client"
|
||||
"github.com/nais/wonderwall/pkg/url"
|
||||
)
|
||||
|
||||
func TestLogoutCallback_PostLogoutRedirectURI(t *testing.T) {
|
||||
t.Run("happy path", func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.OpenID.PostLogoutRedirectURI = "http://some-fancy-logout-page"
|
||||
const defaultState = "some-state"
|
||||
const defaultRedirectURI = "http://some-fancy-logout-page"
|
||||
|
||||
lc := newLogoutCallback(cfg)
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
emptyDefaultURI bool
|
||||
cookie *openid.LogoutCookie
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
expected: defaultRedirectURI,
|
||||
},
|
||||
{
|
||||
name: "empty default uri",
|
||||
emptyDefaultURI: true,
|
||||
expected: mock.Ingress,
|
||||
},
|
||||
{
|
||||
name: "state mismatch",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: "some-other-state",
|
||||
},
|
||||
expected: defaultRedirectURI,
|
||||
},
|
||||
{
|
||||
name: "happy path, redirect in cookie",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: defaultState,
|
||||
RedirectTo: "http://wonderwall/some/path",
|
||||
},
|
||||
expected: "http://wonderwall/some/path",
|
||||
},
|
||||
{
|
||||
name: "empty redirect in cookie",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: defaultState,
|
||||
RedirectTo: "",
|
||||
},
|
||||
expected: defaultRedirectURI,
|
||||
},
|
||||
{
|
||||
name: "state mismatch, with redirect in cookie",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: "some-other-state",
|
||||
RedirectTo: "http://wonderwall/some/path",
|
||||
},
|
||||
expected: defaultRedirectURI,
|
||||
},
|
||||
{
|
||||
name: "invalid redirect in cookie",
|
||||
cookie: &openid.LogoutCookie{
|
||||
State: defaultState,
|
||||
RedirectTo: "http://not-wonderwall/some/path",
|
||||
},
|
||||
expected: defaultRedirectURI,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.OpenID.PostLogoutRedirectURI = defaultRedirectURI
|
||||
|
||||
uri := lc.PostLogoutRedirectURI()
|
||||
assert.NotEmpty(t, uri)
|
||||
assert.Equal(t, "http://some-fancy-logout-page", uri)
|
||||
})
|
||||
if tt.emptyDefaultURI {
|
||||
cfg.OpenID.PostLogoutRedirectURI = ""
|
||||
}
|
||||
|
||||
t.Run("empty preconfigured post-logout redirect uri", func(t *testing.T) {
|
||||
cfg := mock.Config()
|
||||
cfg.OpenID.PostLogoutRedirectURI = ""
|
||||
lc := newLogoutCallback(cfg, defaultState, tt.cookie)
|
||||
|
||||
lc := newLogoutCallback(cfg)
|
||||
|
||||
uri := lc.PostLogoutRedirectURI()
|
||||
assert.NotEmpty(t, uri)
|
||||
assert.Equal(t, mock.Ingress, uri)
|
||||
})
|
||||
uri := lc.PostLogoutRedirectURI()
|
||||
assert.NotEmpty(t, uri)
|
||||
assert.Equal(t, tt.expected, uri)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newLogoutCallback(cfg *config.Config) *client.LogoutCallback {
|
||||
func newLogoutCallback(cfg *config.Config, state string, cookie *openid.LogoutCookie) *client.LogoutCallback {
|
||||
openidCfg := mock.NewTestConfiguration(cfg)
|
||||
ingresses := mock.Ingresses(cfg)
|
||||
req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout/callback", ingresses)
|
||||
return newTestClientWithConfig(openidCfg).LogoutCallback(req)
|
||||
validator := url.NewAbsoluteValidator(ingresses.Hosts())
|
||||
req := mock.NewGetRequest(mock.Ingress+"/oauth2/logout/callback?state="+state, ingresses)
|
||||
return newTestClientWithConfig(openidCfg).LogoutCallback(req, cookie, validator)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,34 @@ func TestLogout_SingleLogoutURL(t *testing.T) {
|
||||
t.Run("with id_token", func(t *testing.T) {
|
||||
logout := newLogout(t)
|
||||
idToken := "some-id-token"
|
||||
state := logout.Cookie.State
|
||||
|
||||
raw := logout.SingleLogoutURL(idToken)
|
||||
assert.NotEmpty(t, raw)
|
||||
|
||||
logoutUrl, err := url.Parse(raw)
|
||||
assert.NoError(t, err)
|
||||
|
||||
query := logoutUrl.Query()
|
||||
assert.Len(t, query, 3)
|
||||
|
||||
assert.Contains(t, query, "id_token_hint")
|
||||
assert.Equal(t, idToken, query.Get("id_token_hint"))
|
||||
|
||||
assert.Contains(t, query, "post_logout_redirect_uri")
|
||||
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
|
||||
|
||||
assert.Contains(t, query, "state")
|
||||
assert.Equal(t, state, query.Get("state"))
|
||||
|
||||
logoutUrl.RawQuery = ""
|
||||
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
|
||||
})
|
||||
|
||||
t.Run("without id_token", func(t *testing.T) {
|
||||
logout := newLogout(t)
|
||||
idToken := ""
|
||||
state := logout.Cookie.State
|
||||
|
||||
raw := logout.SingleLogoutURL(idToken)
|
||||
assert.NotEmpty(t, raw)
|
||||
@@ -30,35 +58,15 @@ func TestLogout_SingleLogoutURL(t *testing.T) {
|
||||
query := logoutUrl.Query()
|
||||
assert.Len(t, query, 2)
|
||||
|
||||
assert.Contains(t, query, "id_token_hint")
|
||||
assert.Equal(t, idToken, query.Get("id_token_hint"))
|
||||
|
||||
assert.Contains(t, query, "post_logout_redirect_uri")
|
||||
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
|
||||
|
||||
logoutUrl.RawQuery = ""
|
||||
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
|
||||
})
|
||||
|
||||
t.Run("without id_token", func(t *testing.T) {
|
||||
logout := newLogout(t)
|
||||
idToken := ""
|
||||
|
||||
raw := logout.SingleLogoutURL(idToken)
|
||||
assert.NotEmpty(t, raw)
|
||||
|
||||
logoutUrl, err := url.Parse(raw)
|
||||
assert.NoError(t, err)
|
||||
|
||||
query := logoutUrl.Query()
|
||||
assert.Len(t, query, 1)
|
||||
|
||||
assert.NotContains(t, query, "id_token_hint")
|
||||
assert.Equal(t, idToken, query.Get("id_token_hint"))
|
||||
|
||||
assert.Contains(t, query, "post_logout_redirect_uri")
|
||||
assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri"))
|
||||
|
||||
assert.Contains(t, query, "state")
|
||||
assert.Equal(t, state, query.Get("state"))
|
||||
|
||||
logoutUrl.RawQuery = ""
|
||||
assert.Equal(t, EndSessionEndpoint, logoutUrl.String())
|
||||
})
|
||||
|
||||
@@ -15,6 +15,8 @@ type Redirect interface {
|
||||
Canonical(r *http.Request) string
|
||||
// Clean parses and cleans a target URL according to implementation-specific validations. It should always return a fallback URL string, regardless of validation errors.
|
||||
Clean(r *http.Request, target string) string
|
||||
// GetValidator returns the Validator used to Clean URLs.
|
||||
GetValidator() Validator
|
||||
}
|
||||
|
||||
var _ Redirect = &StandaloneRedirect{}
|
||||
@@ -33,7 +35,7 @@ func (h *StandaloneRedirect) Canonical(r *http.Request) string {
|
||||
target := redirectQueryParam(r)
|
||||
redirect, err := url.ParseRequestURI(target)
|
||||
if err != nil {
|
||||
redirect = fallback(r, target, h.FallbackRedirect(r))
|
||||
redirect = fallback(r, target, h.getFallbackRedirect(r))
|
||||
}
|
||||
|
||||
// redirect must be a relative URL to avoid cross-domain redirects
|
||||
@@ -44,10 +46,10 @@ func (h *StandaloneRedirect) Canonical(r *http.Request) string {
|
||||
}
|
||||
|
||||
func (h *StandaloneRedirect) Clean(r *http.Request, target string) string {
|
||||
return h.clean(r, target, h.FallbackRedirect(r))
|
||||
return h.clean(r, target, h.getFallbackRedirect(r))
|
||||
}
|
||||
|
||||
func (h *StandaloneRedirect) FallbackRedirect(r *http.Request) *url.URL {
|
||||
func (h *StandaloneRedirect) getFallbackRedirect(r *http.Request) *url.URL {
|
||||
return MatchingPath(r)
|
||||
}
|
||||
|
||||
@@ -145,6 +147,10 @@ func newRelativeCleaner(allowedHosts []string) *cleaner {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cleaner) GetValidator() Validator {
|
||||
return c.Validator
|
||||
}
|
||||
|
||||
func (c *cleaner) clean(r *http.Request, target string, fallbackTarget *url.URL) string {
|
||||
if c.IsValidRedirect(r, target) {
|
||||
return target
|
||||
|
||||
@@ -45,9 +45,11 @@ func LoginRelative(prefix, redirect string) string {
|
||||
func Logout(target *url.URL, redirect string) string {
|
||||
u := target.JoinPath(paths.OAuth2, paths.Logout)
|
||||
|
||||
v := u.Query()
|
||||
v.Set(RedirectQueryParameter, redirect)
|
||||
u.RawQuery = v.Encode()
|
||||
if len(redirect) > 0 {
|
||||
v := u.Query()
|
||||
v.Set(RedirectQueryParameter, redirect)
|
||||
u.RawQuery = v.Encode()
|
||||
}
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
@@ -122,6 +122,12 @@ func TestLogout(t *testing.T) {
|
||||
redirectTarget: "/path?some=param&other=param2",
|
||||
want: "https://sso.wonderwall/oauth2/logout?redirect=%2Fpath%3Fsome%3Dparam%26other%3Dparam2",
|
||||
},
|
||||
{
|
||||
name: "empty redirect target",
|
||||
targetURL: "https://sso.wonderwall",
|
||||
redirectTarget: "",
|
||||
want: "https://sso.wonderwall/oauth2/logout",
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
targetURL, err := url.Parse(test.targetURL)
|
||||
|
||||
Reference in New Issue
Block a user