From 32dd80b5daeb55b9c7712db4fef1643c21fe0100 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Mon, 9 May 2022 11:50:19 +0200 Subject: [PATCH] feat: add handler for logout callbacks --- README.md | 1 + pkg/mock/client_configuration.go | 14 ++++-- pkg/mock/handler.go | 24 +++++++++ pkg/mock/openid.go | 1 + pkg/openid/clients/configuration.go | 3 +- pkg/openid/clients/openid.go | 22 ++++++--- pkg/openid/cookies.go | 5 ++ pkg/openid/provider.go | 15 ++++-- pkg/openid/redirect_uri.go | 4 +- pkg/openid/redirect_uri_test.go | 13 ++++- pkg/router/cookies.go | 1 + pkg/router/handler.go | 2 +- pkg/router/handler_logout.go | 71 +++++++++++++++++++++++---- pkg/router/handler_logout_callback.go | 63 ++++++++++++++++++++++++ pkg/router/login_url.go | 2 +- pkg/router/login_url_test.go | 2 +- pkg/router/paths/paths.go | 1 + pkg/router/request/parameters.go | 7 ++- pkg/router/request/request.go | 9 ---- pkg/router/router.go | 1 + pkg/router/router_test.go | 52 +++++++++++++++++--- 21 files changed, 262 insertions(+), 51 deletions(-) create mode 100644 pkg/router/handler_logout_callback.go diff --git a/README.md b/README.md index 685a80f..f4b32af 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,7 @@ Wonderwall exposes and owns these endpoints (which means they will never be prox | `/oauth2/login` | Initiates the OpenID Connect Authorization Code flow | | `/oauth2/callback` | Handles the callback from the identity provider | | `/oauth2/logout` | Initiates local and global/single-logout | +| `/oauth2/logout/callback` | Handles the logout callback from the identity provider | | `/oauth2/logout/frontchannel` | Handles global logout request (initiated by identity provider on behalf of another client) | ## Usage diff --git a/pkg/mock/client_configuration.go b/pkg/mock/client_configuration.go index 7a4f2d4..ccdbe46 100644 --- a/pkg/mock/client_configuration.go +++ b/pkg/mock/client_configuration.go @@ -10,7 +10,8 @@ import ( type TestClientConfiguration struct { ClientID string ClientJWK jwk.Key - RedirectURI string + CallbackURI string + LogoutCallbackURI string PostLogoutRedirectURI string Scopes scopes.Scopes ACRValues string @@ -18,8 +19,8 @@ type TestClientConfiguration struct { WellKnownURL string } -func (c TestClientConfiguration) GetRedirectURI() string { - return c.RedirectURI +func (c TestClientConfiguration) GetCallbackURI() string { + return c.CallbackURI } func (c TestClientConfiguration) GetClientID() string { @@ -30,6 +31,10 @@ func (c TestClientConfiguration) GetClientJWK() jwk.Key { return c.ClientJWK } +func (c TestClientConfiguration) GetLogoutCallbackURI() string { + return c.LogoutCallbackURI +} + func (c TestClientConfiguration) GetPostLogoutRedirectURI() string { return c.PostLogoutRedirectURI } @@ -59,7 +64,8 @@ func clientConfiguration() TestClientConfiguration { return TestClientConfiguration{ ClientID: "client_id", ClientJWK: key, - RedirectURI: "http://localhost/callback", + CallbackURI: "http://localhost/callback", + LogoutCallbackURI: "http://localhost/logout/callback", WellKnownURL: "", UILocales: "nb", ACRValues: "Level4", diff --git a/pkg/mock/handler.go b/pkg/mock/handler.go index 5c8eb29..a23489f 100644 --- a/pkg/mock/handler.go +++ b/pkg/mock/handler.go @@ -223,3 +223,27 @@ func (ip *identityProviderHandler) Token(w http.ResponseWriter, r *http.Request) w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(token) } + +func (ip *identityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + state := query.Get("state") + postLogoutRedirectURI := query.Get("post_logout_redirect_uri") + + if state == "" || postLogoutRedirectURI == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("missing required parameters")) + return + } + + u, err := url.Parse(postLogoutRedirectURI) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("couldn't parse post_logout_redirect_uri")) + return + } + v := url.Values{} + v.Set("state", state) + + u.RawQuery = v.Encode() + http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) +} diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index c8e6409..caba55b 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -26,5 +26,6 @@ func identityProviderRouter(ip *identityProviderHandler) chi.Router { r.Get("/authorize", ip.Authorize) r.Post("/token", ip.Token) r.Get("/jwks", ip.Jwks) + r.Get("/endsession", ip.EndSession) return r } diff --git a/pkg/openid/clients/configuration.go b/pkg/openid/clients/configuration.go index 05a9924..42a1b5b 100644 --- a/pkg/openid/clients/configuration.go +++ b/pkg/openid/clients/configuration.go @@ -10,7 +10,8 @@ type Configuration interface { GetClientID() string GetClientJWK() jwk.Key GetPostLogoutRedirectURI() string - GetRedirectURI() string + GetCallbackURI() string + GetLogoutCallbackURI() string GetScopes() scopes.Scopes GetACRValues() string GetUILocales() string diff --git a/pkg/openid/clients/openid.go b/pkg/openid/clients/openid.go index 3f0a92d..9f5847a 100644 --- a/pkg/openid/clients/openid.go +++ b/pkg/openid/clients/openid.go @@ -9,12 +9,13 @@ import ( type OpenIDConfig struct { config.OpenID - clientJwk jwk.Key - redirectURI string + clientJwk jwk.Key + callbackURI string + logoutCallbackURI string } -func (in *OpenIDConfig) GetRedirectURI() string { - return in.redirectURI +func (in *OpenIDConfig) GetCallbackURI() string { + return in.callbackURI } func (in *OpenIDConfig) GetClientID() string { @@ -25,6 +26,10 @@ func (in *OpenIDConfig) GetClientJWK() jwk.Key { return in.clientJwk } +func (in *OpenIDConfig) GetLogoutCallbackURI() string { + return in.logoutCallbackURI +} + func (in *OpenIDConfig) GetPostLogoutRedirectURI() string { return in.PostLogoutRedirectURI } @@ -45,10 +50,11 @@ func (in *OpenIDConfig) GetWellKnownURL() string { return in.WellKnownURL } -func NewOpenIDConfig(cfg config.Config, clientJwk jwk.Key, redirectURI string) *OpenIDConfig { +func NewOpenIDConfig(cfg config.Config, clientJwk jwk.Key, callbackURI, logoutCallbackURI string) *OpenIDConfig { return &OpenIDConfig{ - OpenID: cfg.OpenID, - clientJwk: clientJwk, - redirectURI: redirectURI, + OpenID: cfg.OpenID, + clientJwk: clientJwk, + callbackURI: callbackURI, + logoutCallbackURI: logoutCallbackURI, } } diff --git a/pkg/openid/cookies.go b/pkg/openid/cookies.go index 71167e2..aabcaeb 100644 --- a/pkg/openid/cookies.go +++ b/pkg/openid/cookies.go @@ -6,3 +6,8 @@ type LoginCookie struct { CodeVerifier string `json:"code_verifier"` Referer string `json:"referer"` } + +type LogoutCookie struct { + State string `json:"state"` + RedirectTo string `json:"redirect_to"` +} diff --git a/pkg/openid/provider.go b/pkg/openid/provider.go index 454d5c3..4158c47 100644 --- a/pkg/openid/provider.go +++ b/pkg/openid/provider.go @@ -11,6 +11,7 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/openid/clients" + "github.com/nais/wonderwall/pkg/router/paths" ) const ( @@ -91,12 +92,17 @@ func NewProvider(ctx context.Context, cfg *config.Config) (Provider, error) { return nil, fmt.Errorf("missing required config %s", config.Ingress) } - redirectURI, err := RedirectURI(ingress) + callbackURI, err := RedirectURI(ingress, paths.Callback) if err != nil { - return nil, fmt.Errorf("creating redirect URI from ingress: %w", err) + return nil, fmt.Errorf("creating callback URI from ingress: %w", err) } - openIDConfig := clients.NewOpenIDConfig(*cfg, clientJwk, redirectURI) + logoutCallbackURI, err := RedirectURI(ingress, paths.LogoutCallback) + if err != nil { + return nil, fmt.Errorf("creating logout callback URI from ingress: %w", err) + } + + openIDConfig := clients.NewOpenIDConfig(*cfg, clientJwk, callbackURI, logoutCallbackURI) var clientConfig clients.Configuration switch cfg.OpenID.Provider { case config.ProviderIDPorten: @@ -161,7 +167,8 @@ func printConfigs(clientCfg clients.Configuration, openIdCfg Configuration) { log.Infof("acr values: '%s'", clientCfg.GetACRValues()) log.Infof("client id: '%s'", clientCfg.GetClientID()) log.Infof("post-logout redirect uri: '%s'", clientCfg.GetPostLogoutRedirectURI()) - log.Infof("redirect uri: '%s'", clientCfg.GetRedirectURI()) + log.Infof("callback uri: '%s'", clientCfg.GetCallbackURI()) + log.Infof("logout callback uri: '%s'", clientCfg.GetLogoutCallbackURI()) log.Infof("scopes: '%s'", clientCfg.GetScopes()) log.Infof("ui locales: '%s'", clientCfg.GetUILocales()) diff --git a/pkg/openid/redirect_uri.go b/pkg/openid/redirect_uri.go index d7906aa..4334971 100644 --- a/pkg/openid/redirect_uri.go +++ b/pkg/openid/redirect_uri.go @@ -8,7 +8,7 @@ import ( "github.com/nais/wonderwall/pkg/router/paths" ) -func RedirectURI(ingress string) (string, error) { +func RedirectURI(ingress, redirectPath string) (string, error) { if len(ingress) == 0 { return "", fmt.Errorf("ingress cannot be empty") } @@ -18,6 +18,6 @@ func RedirectURI(ingress string) (string, error) { return "", err } - base.Path = path.Join(base.Path, paths.OAuth2, paths.Callback) + base.Path = path.Join(base.Path, paths.OAuth2, redirectPath) return base.String(), nil } diff --git a/pkg/openid/redirect_uri_test.go b/pkg/openid/redirect_uri_test.go index 0033797..d4a150b 100644 --- a/pkg/openid/redirect_uri_test.go +++ b/pkg/openid/redirect_uri_test.go @@ -7,36 +7,47 @@ import ( "github.com/stretchr/testify/assert" "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/router/paths" ) func TestRedirectURI(t *testing.T) { for _, test := range []struct { input string + path string want string err error }{ { input: "https://nav.no/dagpenger", + path: paths.Callback, want: "https://nav.no/dagpenger/oauth2/callback", }, { input: "https://nav.no/dagpenger/soknad", + path: paths.Callback, want: "https://nav.no/dagpenger/soknad/oauth2/callback", }, { input: "https://nav.no", + path: paths.Callback, want: "https://nav.no/oauth2/callback", }, { input: "https://nav.no/", + path: paths.Callback, want: "https://nav.no/oauth2/callback", }, + { + input: "https://nav.no/", + path: paths.LogoutCallback, + want: "https://nav.no/oauth2/logout/callback", + }, { input: "", err: fmt.Errorf("ingress cannot be empty"), }, } { - actual, err := openid.RedirectURI(test.input) + actual, err := openid.RedirectURI(test.input, test.path) if test.err != nil { assert.EqualError(t, err, test.err.Error()) } else { diff --git a/pkg/router/cookies.go b/pkg/router/cookies.go index e970648..6b17725 100644 --- a/pkg/router/cookies.go +++ b/pkg/router/cookies.go @@ -10,6 +10,7 @@ const ( SessionCookieName = "io.nais.wonderwall.session" LoginCookieName = "io.nais.wonderwall.callback" LoginLegacyCookieName = "io.nais.wonderwall.callback.legacy" + LogoutCookieName = "io.nais.wonderwall.logout" ) func (h *Handler) setEncryptedCookie(w http.ResponseWriter, key string, plaintext string, opts cookie.Options) error { diff --git a/pkg/router/handler.go b/pkg/router/handler.go index 912ee11..695c06c 100644 --- a/pkg/router/handler.go +++ b/pkg/router/handler.go @@ -40,7 +40,7 @@ func NewHandler( AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint, TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint, }, - RedirectURL: provider.GetClientConfiguration().GetRedirectURI(), + RedirectURL: provider.GetClientConfiguration().GetCallbackURI(), Scopes: provider.GetClientConfiguration().GetScopes(), } loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient) diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index a9c28fb..3dadd8a 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -1,13 +1,17 @@ package router import ( + "encoding/json" + "errors" "fmt" "net/http" "net/url" - log "github.com/sirupsen/logrus" + "github.com/go-redis/redis/v8" - "github.com/nais/wonderwall/pkg/router/request" + "github.com/nais/wonderwall/pkg/openid" + logentry "github.com/nais/wonderwall/pkg/router/middleware" + "github.com/nais/wonderwall/pkg/strings" ) // Logout triggers self-initiated for the current user @@ -24,12 +28,16 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { if err == nil && sessionData != nil { idToken = sessionData.IDToken err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID)) - if err != nil { + if err != nil && !errors.Is(err, redis.Nil) { h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) return } - log.WithField("claims", sessionData.Claims).Infof("logout: successful logout") + fields := map[string]interface{}{ + "claims": sessionData.Claims, + } + logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger() + logger.Info().Msg("logout: successful local logout") } h.deleteCookie(w, SessionCookieName, h.CookieOptions) @@ -38,18 +46,63 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { h.Loginstatus.ClearCookie(w, h.CookieOptions) } - v := u.Query() - - postLogoutURI := request.PostLogoutRedirectURI(r, h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI()) - if len(postLogoutURI) > 0 { - v.Add("post_logout_redirect_uri", postLogoutURI) + logoutCookie, err := h.logoutCookie() + if err != nil { + h.InternalError(w, r, fmt.Errorf("logout: generating logout cookie: %w", err)) + return } + err = h.setLogoutCookie(w, logoutCookie) + if err != nil { + h.InternalError(w, r, fmt.Errorf("logout: setting logout cookie: %w", err)) + return + } + + v := u.Query() + v.Add("post_logout_redirect_uri", h.Provider.GetClientConfiguration().GetLogoutCallbackURI()) + v.Add("state", logoutCookie.State) + if len(idToken) > 0 { v.Add("id_token_hint", idToken) } u.RawQuery = v.Encode() + fields := map[string]interface{}{ + "redirect_to": logoutCookie.RedirectTo, + } + logger := logentry.LogEntry(r.Context()).With().Fields(fields).Logger() + logger.Info().Msg("logout: redirecting to identity provider") + http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) } + +func (h *Handler) logoutCookie() (*openid.LogoutCookie, error) { + state, err := strings.GenerateBase64(32) + if err != nil { + return nil, fmt.Errorf("generating state: %w", err) + } + + return &openid.LogoutCookie{ + State: state, + RedirectTo: h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI(), + }, nil +} + +func (h *Handler) setLogoutCookie(w http.ResponseWriter, logoutCookie *openid.LogoutCookie) error { + logoutCookieJson, err := json.Marshal(logoutCookie) + if err != nil { + return fmt.Errorf("marshalling login cookie: %w", err) + } + + opts := h.CookieOptions. + WithExpiresIn(LogoutCookieLifetime) + value := string(logoutCookieJson) + + err = h.setEncryptedCookie(w, LogoutCookieName, value, opts) + if err != nil { + return err + } + + return nil +} diff --git a/pkg/router/handler_logout_callback.go b/pkg/router/handler_logout_callback.go new file mode 100644 index 0000000..21297e6 --- /dev/null +++ b/pkg/router/handler_logout_callback.go @@ -0,0 +1,63 @@ +package router + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/nais/wonderwall/pkg/openid" + logentry "github.com/nais/wonderwall/pkg/router/middleware" +) + +const ( + LogoutCookieLifetime = 5 * time.Minute +) + +// LogoutCallback handles the callback from the self-initiated logout for the current user +func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) { + h.deleteCookie(w, LogoutCookieName, h.CookieOptions) + + logger := logentry.LogEntry(r.Context()) + + logoutCookie, err := h.getLogoutCookie(r) + if err != nil { + logger.Warn().Msgf("logout/callback: getting cookie: %+v", err) + http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect) + return + } + + params := r.URL.Query() + expectedState := logoutCookie.State + actualState := params.Get("state") + + if expectedState != actualState { + logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s", expectedState, actualState) + http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect) + return + } + + if len(logoutCookie.RedirectTo) == 0 { + logger.Warn().Msgf("logout/callback: empty redirect; falling back to ingress") + http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect) + return + } + + logger.Info().Msgf("logout/callback: redirecting to %s", logoutCookie.RedirectTo) + http.Redirect(w, r, logoutCookie.RedirectTo, http.StatusTemporaryRedirect) +} + +func (h *Handler) getLogoutCookie(r *http.Request) (*openid.LogoutCookie, error) { + logoutCookieJson, err := h.getDecryptedCookie(r, LogoutCookieName) + if err != nil { + return nil, err + } + + var logoutCookie openid.LogoutCookie + err = json.Unmarshal([]byte(logoutCookieJson), &logoutCookie) + if err != nil { + return nil, fmt.Errorf("unmarshalling: %w", err) + } + + return &logoutCookie, nil +} diff --git a/pkg/router/login_url.go b/pkg/router/login_url.go index a77c17d..1990b0e 100644 --- a/pkg/router/login_url.go +++ b/pkg/router/login_url.go @@ -24,7 +24,7 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.LoginParameters) (str v := u.Query() v.Add("response_type", "code") v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID()) - v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetRedirectURI()) + v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetCallbackURI()) v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String()) v.Add("state", params.State) v.Add("nonce", params.Nonce) diff --git a/pkg/router/login_url_test.go b/pkg/router/login_url_test.go index 2ee8684..01b4c61 100644 --- a/pkg/router/login_url_test.go +++ b/pkg/router/login_url_test.go @@ -81,7 +81,7 @@ func TestLoginURL(t *testing.T) { assert.ElementsMatch(t, query["response_type"], []string{"code"}) assert.ElementsMatch(t, query["client_id"], []string{provider.ClientConfiguration.ClientID}) - assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.RedirectURI}) + assert.ElementsMatch(t, query["redirect_uri"], []string{provider.ClientConfiguration.CallbackURI}) assert.ElementsMatch(t, query["scope"], []string{provider.ClientConfiguration.GetScopes().String()}) assert.ElementsMatch(t, query["state"], []string{params.State}) assert.ElementsMatch(t, query["nonce"], []string{params.Nonce}) diff --git a/pkg/router/paths/paths.go b/pkg/router/paths/paths.go index cbedc5b..91482d3 100644 --- a/pkg/router/paths/paths.go +++ b/pkg/router/paths/paths.go @@ -5,5 +5,6 @@ const ( Login = "/login" Callback = "/callback" Logout = "/logout" + LogoutCallback = "/logout/callback" FrontChannelLogout = "/logout/frontchannel" ) diff --git a/pkg/router/request/parameters.go b/pkg/router/request/parameters.go index 4478566..d6532f4 100644 --- a/pkg/router/request/parameters.go +++ b/pkg/router/request/parameters.go @@ -1,8 +1,7 @@ package request const ( - LocaleURLParameter = "locale" - PostLogoutRedirectURIParameter = "post_logout_redirect_uri" - RedirectURLParameter = "redirect" - SecurityLevelURLParameter = "level" + LocaleURLParameter = "locale" + RedirectURLParameter = "redirect" + SecurityLevelURLParameter = "level" ) diff --git a/pkg/router/request/request.go b/pkg/router/request/request.go index 772996a..fc0753d 100644 --- a/pkg/router/request/request.go +++ b/pkg/router/request/request.go @@ -93,15 +93,6 @@ func LoginURLParameter(r *http.Request, parameter, fallback string, supported op return value, fmt.Errorf("%w: invalid value for %s=%s", InvalidLoginParameterError, parameter, value) } -func PostLogoutRedirectURI(r *http.Request, fallback string) string { - value := r.URL.Query().Get(PostLogoutRedirectURIParameter) - - if len(value) > 0 { - return value - } - return fallback -} - func refererPath(r *http.Request) string { if len(r.Referer()) == 0 { return "" diff --git a/pkg/router/router.go b/pkg/router/router.go index cb03362..6f009d6 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -25,6 +25,7 @@ func New(handler *Handler) chi.Router { r.Get(paths.Callback, handler.Callback) r.Get(paths.Logout, handler.Logout) r.Get(paths.FrontChannelLogout, handler.FrontChannelLogout) + r.Get(paths.LogoutCallback, handler.LogoutCallback) }) r.HandleFunc("/*", handler.Default) return r diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index bb4bd74..4b06fc7 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -81,7 +81,7 @@ func TestHandler_Login(t *testing.T) { assert.Equal(t, idp.GetClientConfiguration().GetACRValues(), u.Query().Get("acr_values")) assert.Equal(t, idp.GetClientConfiguration().GetUILocales(), u.Query().Get("ui_locales")) assert.Equal(t, idp.GetClientConfiguration().GetClientID(), u.Query().Get("client_id")) - assert.Equal(t, idp.GetClientConfiguration().GetRedirectURI(), u.Query().Get("redirect_uri")) + assert.Equal(t, idp.GetClientConfiguration().GetCallbackURI(), u.Query().Get("redirect_uri")) assert.NotEmpty(t, u.Query().Get("state")) assert.NotEmpty(t, u.Query().Get("nonce")) assert.NotEmpty(t, u.Query().Get("code_challenge")) @@ -106,8 +106,9 @@ func TestHandler_Callback_and_Logout(t *testing.T) { r := router.New(h) server := httptest.NewServer(r) - idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback" + idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback" idp.ClientConfiguration.PostLogoutRedirectURI = server.URL + idp.ClientConfiguration.LogoutCallbackURI = server.URL + "/oauth2/logout/callback" jar, err := cookiejar.New(nil) assert.NoError(t, err) @@ -175,8 +176,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) { cookies = client.Jar.Cookies(logoutURL) sessionCookie = getCookieFromJar(router.SessionCookieName, cookies) + logoutCookie := getCookieFromJar(router.LogoutCookieName, cookies) assert.Nil(t, sessionCookie) + assert.NotNil(t, logoutCookie) // Get endsession endpoint after local logout location = resp.Header.Get("location") @@ -187,11 +190,48 @@ func TestHandler_Callback_and_Logout(t *testing.T) { assert.NoError(t, err) endsessionParams := endsessionURL.Query() - + expectedState := endsessionParams["state"] assert.Equal(t, idpserverURL.Host, endsessionURL.Host) assert.Equal(t, "/endsession", endsessionURL.Path) - assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.GetClientConfiguration().GetPostLogoutRedirectURI()}) + assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.GetClientConfiguration().GetLogoutCallbackURI()}) assert.NotEmpty(t, endsessionParams["id_token_hint"]) + assert.NotEmpty(t, expectedState) + + // Follow redirect to endsession endpoint at identity provider + resp, err = client.Get(endsessionURL.String()) + assert.NoError(t, err) + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + defer resp.Body.Close() + + // Get post-logout redirect URI after successful logout at identity provider + location = resp.Header.Get("location") + logoutCallbackURI, err := url.Parse(location) + assert.NoError(t, err) + assert.Contains(t, logoutCallbackURI.String(), idp.ClientConfiguration.GetLogoutCallbackURI()) + logoutCallbackParams := endsessionURL.Query() + actualState := logoutCallbackParams["state"] + + assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path) + assert.NotEmpty(t, actualState) + assert.Equal(t, expectedState, actualState) + + // Follow redirect back to logout callback + resp, err = client.Get(logoutCallbackURI.String()) + assert.NoError(t, err) + assert.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) + + // Get post-logout redirect URI after redirect back to logout callback + location = resp.Header.Get("location") + postLogoutRedirectURI, err := url.Parse(location) + assert.NoError(t, err) + assert.Equal(t, idp.ClientConfiguration.GetPostLogoutRedirectURI(), postLogoutRedirectURI.String()) + + cookies = client.Jar.Cookies(logoutCallbackURI) + sessionCookie = getCookieFromJar(router.SessionCookieName, cookies) + logoutCookie = getCookieFromJar(router.LogoutCookieName, cookies) + + assert.Nil(t, sessionCookie) + assert.Nil(t, logoutCookie) } func TestHandler_FrontChannelLogout(t *testing.T) { @@ -202,7 +242,7 @@ func TestHandler_FrontChannelLogout(t *testing.T) { r := router.New(h) server := httptest.NewServer(r) - idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback" + idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback" idp.ClientConfiguration.PostLogoutRedirectURI = server.URL jar, err := cookiejar.New(nil) @@ -276,7 +316,7 @@ func TestHandler_SessionStateRequired(t *testing.T) { r := router.New(h) server := httptest.NewServer(r) - idp.ClientConfiguration.RedirectURI = server.URL + "/oauth2/callback" + idp.ClientConfiguration.CallbackURI = server.URL + "/oauth2/callback" idp.ClientConfiguration.PostLogoutRedirectURI = server.URL jar, err := cookiejar.New(nil)