From 66cf08e60229571b4306e8ca6a5aa2e2cdd67e12 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Tue, 12 Jul 2022 15:09:40 +0200 Subject: [PATCH] refactor(openid/logout): simplify logout logic As we already clear any local sessions before redirecting to the Identity Provider, and the callback always redirects to a pre-configured URL, there isn't really any need to maintain and verify state in the logout callback. In other words, the logout callback handler is simply a redirect handler. --- pkg/cookie/cookie.go | 1 - pkg/mock/openid.go | 6 +- pkg/openid/client/client.go | 11 +-- pkg/openid/client/login_callback_test.go | 1 - pkg/openid/client/logout.go | 26 ------- pkg/openid/client/logout_callback.go | 59 ++++------------ pkg/openid/client/logout_callback_test.go | 83 ++++++----------------- pkg/openid/client/logout_test.go | 74 +++++++------------- pkg/openid/cookies.go | 5 -- pkg/router/handler_logout.go | 29 +------- pkg/router/handler_logout_callback.go | 45 +----------- pkg/router/middleware/logentry.go | 3 +- pkg/router/router_test.go | 10 --- 13 files changed, 66 insertions(+), 287 deletions(-) diff --git a/pkg/cookie/cookie.go b/pkg/cookie/cookie.go index 241f489..13f2cb4 100644 --- a/pkg/cookie/cookie.go +++ b/pkg/cookie/cookie.go @@ -13,7 +13,6 @@ const ( Session = "io.nais.wonderwall.session" Login = "io.nais.wonderwall.callback" LoginLegacy = "io.nais.wonderwall.callback.legacy" - Logout = "io.nais.wonderwall.logout" ) type Cookie struct { diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index 557b88c..0d71b64 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -318,10 +318,9 @@ func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Request) { query := r.URL.Query() - state := query.Get("state") postLogoutRedirectURI := query.Get("post_logout_redirect_uri") - if state == "" || postLogoutRedirectURI == "" { + if postLogoutRedirectURI == "" { w.WriteHeader(http.StatusBadRequest) w.Write([]byte("missing required parameters")) return @@ -333,9 +332,6 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req w.Write([]byte("couldn't parse post_logout_redirect_uri")) return } - v := url.Values{} - v.Set("state", state) - u.RawQuery = v.Encode() http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) } diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index c8cb903..76660b0 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -22,7 +22,7 @@ type Client interface { Login(r *http.Request) (Login, error) LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error) Logout() (Logout, error) - LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) + LogoutCallback(r *http.Request) LogoutCallback LogoutFrontchannel(r *http.Request) LogoutFrontchannel AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error) @@ -88,13 +88,8 @@ func (c client) Logout() (Logout, error) { return logout, nil } -func (c client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) { - logoutCallback, err := NewLogoutCallback(r, cookie) - if err != nil { - return nil, fmt.Errorf("logout/callback: %w", err) - } - - return logoutCallback, nil +func (c client) LogoutCallback(r *http.Request) LogoutCallback { + return NewLogoutCallback(c, r) } func (c client) LogoutFrontchannel(r *http.Request) LogoutFrontchannel { diff --git a/pkg/openid/client/login_callback_test.go b/pkg/openid/client/login_callback_test.go index e456ddb..088a01f 100644 --- a/pkg/openid/client/login_callback_test.go +++ b/pkg/openid/client/login_callback_test.go @@ -160,7 +160,6 @@ func newLoginCallback(t *testing.T, url string, cookie *openid.LoginCookie) (moc cfg := idp.OpenIDConfig cfg.ClientConfig.LogoutCallbackURI = LogoutCallbackURI - cfg.ClientConfig.PostLogoutRedirectURI = PostLogoutRedirectURI cfg.ProviderConfig.EndSessionEndpoint = EndSessionEndpoint loginCallback, err := newTestClientWithConfig(cfg).LoginCallback(req, idp.Provider, cookie) diff --git a/pkg/openid/client/logout.go b/pkg/openid/client/logout.go index 0610017..2609fac 100644 --- a/pkg/openid/client/logout.go +++ b/pkg/openid/client/logout.go @@ -3,34 +3,18 @@ package client import ( "fmt" "net/url" - - "github.com/nais/wonderwall/pkg/openid" - "github.com/nais/wonderwall/pkg/strings" ) type Logout interface { - CanonicalRedirect() string - Cookie() *openid.LogoutCookie SingleLogoutURL(idToken string) string } type logout struct { Client - cookie *openid.LogoutCookie endSessionEndpoint *url.URL } func NewLogout(c Client) (Logout, error) { - state, err := strings.GenerateBase64(32) - if err != nil { - return nil, fmt.Errorf("generating state: %w", err) - } - - cookie := &openid.LogoutCookie{ - State: state, - RedirectTo: c.config().Client().GetPostLogoutRedirectURI(), - } - endSessionEndpoint, err := url.Parse(c.config().Provider().EndSessionEndpoint) if err != nil { return nil, fmt.Errorf("parsing end session endpoint: %w", err) @@ -38,23 +22,13 @@ func NewLogout(c Client) (Logout, error) { return &logout{ Client: c, - cookie: cookie, endSessionEndpoint: endSessionEndpoint, }, nil } -func (in logout) CanonicalRedirect() string { - return in.cookie.RedirectTo -} - -func (in logout) Cookie() *openid.LogoutCookie { - return in.cookie -} - func (in logout) SingleLogoutURL(idToken string) string { v := in.endSessionEndpoint.Query() v.Add("post_logout_redirect_uri", in.config().Client().GetLogoutCallbackURI()) - v.Add("state", in.cookie.State) if len(idToken) > 0 { v.Add("id_token_hint", idToken) diff --git a/pkg/openid/client/logout_callback.go b/pkg/openid/client/logout_callback.go index 61bcf69..9fd76b4 100644 --- a/pkg/openid/client/logout_callback.go +++ b/pkg/openid/client/logout_callback.go @@ -1,64 +1,31 @@ package client import ( - "fmt" "net/http" - "net/url" - - "github.com/nais/wonderwall/pkg/openid" ) type LogoutCallback interface { - ValidateRequest() error + PostLogoutRedirectURI() string } type logoutCallback struct { - cookie *openid.LogoutCookie - requestParams url.Values + Client + request *http.Request } -func NewLogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) { - if cookie == nil { - return nil, fmt.Errorf("cookie is nil") - } - +func NewLogoutCallback(c Client, r *http.Request) LogoutCallback { return &logoutCallback{ - requestParams: r.URL.Query(), - cookie: cookie, - }, nil + Client: c, + request: r, + } } -func (in logoutCallback) ValidateRequest() error { - if err := in.emptyRedirectError(); err != nil { - return err +func (in logoutCallback) PostLogoutRedirectURI() string { + redirect := in.config().Client().GetPostLogoutRedirectURI() + + if len(redirect) == 0 { + return in.config().Wonderwall().Ingress } - if err := in.stateMismatchError(); err != nil { - return err - } - - return nil -} - -func (in logoutCallback) emptyRedirectError() error { - if len(in.cookie.RedirectTo) == 0 { - return fmt.Errorf("empty redirect") - } - - return nil -} - -func (in logoutCallback) stateMismatchError() error { - expectedState := in.cookie.State - actualState := in.requestParams.Get("state") - - if len(actualState) <= 0 { - return fmt.Errorf("missing state parameter in request (possible csrf)") - } - - if expectedState != actualState { - return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState) - } - - return nil + return redirect } diff --git a/pkg/openid/client/logout_callback_test.go b/pkg/openid/client/logout_callback_test.go index 57cd692..eca5aae 100644 --- a/pkg/openid/client/logout_callback_test.go +++ b/pkg/openid/client/logout_callback_test.go @@ -6,76 +6,35 @@ import ( "github.com/stretchr/testify/assert" - "github.com/nais/wonderwall/pkg/openid" + "github.com/nais/wonderwall/pkg/mock" "github.com/nais/wonderwall/pkg/openid/client" ) -func TestLogoutCallback_ValidateRequest(t *testing.T) { - t.Run("nil cookie", func(t *testing.T) { - _, err := newLogoutCallback(t, "http://localhost/oauth2/logout/callback?state=some-state", nil) - assert.Error(t, err) +func TestLogoutCallback_PostLogoutRedirectURI(t *testing.T) { + t.Run("happy path", func(t *testing.T) { + lc, cfg := newLogoutCallback(t) + cfg.ClientConfig.PostLogoutRedirectURI = "http://some-fancy-logout-page" + + uri := lc.PostLogoutRedirectURI() + assert.NotEmpty(t, uri) + assert.Equal(t, "http://some-fancy-logout-page", uri) }) - for _, test := range []struct { - name string - url string - cookie *openid.LogoutCookie - wantErr bool - }{ - { - name: "valid request", - url: "http://localhost/oauth2/logout/callback?state=some-state", - cookie: &openid.LogoutCookie{ - State: "some-state", - RedirectTo: "http://some-url", - }, - wantErr: false, - }, - { - name: "empty redirect", - url: "http://localhost/oauth2/logout/callback?state=some-state", - cookie: &openid.LogoutCookie{ - State: "some-state", - RedirectTo: "", - }, - wantErr: true, - }, - { - name: "empty state", - url: "http://localhost/oauth2/logout/callback", - cookie: &openid.LogoutCookie{ - State: "some-state", - RedirectTo: "http://some-url", - }, - wantErr: true, - }, - { - name: "state mismatch", - url: "http://localhost/oauth2/logout/callback?state=some-other-state", - cookie: &openid.LogoutCookie{ - State: "some-state", - RedirectTo: "http://some-url", - }, - wantErr: true, - }, - } { - t.Run(test.name, func(t *testing.T) { - lc, err := newLogoutCallback(t, test.url, test.cookie) - assert.NoError(t, err) + t.Run("empty preconfigured post-logout redirect uri", func(t *testing.T) { + lc, cfg := newLogoutCallback(t) + cfg.ClientConfig.PostLogoutRedirectURI = "" + cfg.WonderwallConfig.Ingress = "http://wonderwall" - err = lc.ValidateRequest() - if test.wantErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - }) - } + uri := lc.PostLogoutRedirectURI() + assert.NotEmpty(t, uri) + assert.Equal(t, "http://wonderwall", uri) + }) } -func newLogoutCallback(t *testing.T, url string, cookie *openid.LogoutCookie) (client.LogoutCallback, error) { - req, err := http.NewRequest("GET", url, nil) +func newLogoutCallback(t *testing.T) (client.LogoutCallback, mock.Configuration) { + req, err := http.NewRequest("GET", "http://wonderwall/oauth2/logout/callback", nil) assert.NoError(t, err) - return newTestClient().LogoutCallback(req, cookie) + cfg := mock.NewTestConfiguration(mock.Config()) + return newTestClientWithConfig(cfg).LogoutCallback(req), cfg } diff --git a/pkg/openid/client/logout_test.go b/pkg/openid/client/logout_test.go index ebef131..975c9c9 100644 --- a/pkg/openid/client/logout_test.go +++ b/pkg/openid/client/logout_test.go @@ -16,28 +16,9 @@ const ( EndSessionEndpoint = "http://provider/endsession" ) -func TestLogout_CanonicalRedirect(t *testing.T) { - logout := newLogout(t) - canonicalRedirect := logout.CanonicalRedirect() - - assert.Equal(t, PostLogoutRedirectURI, canonicalRedirect) -} - -func TestLogout_Cookie(t *testing.T) { - logout := newLogout(t) - cookie := logout.Cookie() - - assert.NotNil(t, cookie) - assert.NotEmpty(t, cookie.State) - assert.NotEmpty(t, cookie.RedirectTo) -} - func TestLogout_SingleLogoutURL(t *testing.T) { t.Run("with id_token", func(t *testing.T) { logout := newLogout(t) - cookie := logout.Cookie() - - state := cookie.State idToken := "some-id-token" raw := logout.SingleLogoutURL(idToken) @@ -46,43 +27,34 @@ func TestLogout_SingleLogoutURL(t *testing.T) { logoutUrl, err := url.Parse(raw) assert.NoError(t, err) - query := logoutUrl.Query() - assert.Len(t, query, 3) - - assert.Contains(t, query, "id_token_hint") - assert.Equal(t, idToken, query.Get("id_token_hint")) - - assert.Contains(t, query, "state") - assert.Equal(t, state, query.Get("state")) - - assert.Contains(t, query, "post_logout_redirect_uri") - assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) - - logoutUrl.RawQuery = "" - assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) - }) - - t.Run("without id_token", func(t *testing.T) { - logout := newLogout(t) - cookie := logout.Cookie() - - state := cookie.State - idToken := "" - - raw := logout.SingleLogoutURL(idToken) - assert.NotEmpty(t, raw) - - logoutUrl, err := url.Parse(raw) - assert.NoError(t, err) - query := logoutUrl.Query() assert.Len(t, query, 2) - assert.NotContains(t, query, "id_token_hint") + assert.Contains(t, query, "id_token_hint") assert.Equal(t, idToken, query.Get("id_token_hint")) - assert.Contains(t, query, "state") - assert.Equal(t, state, query.Get("state")) + assert.Contains(t, query, "post_logout_redirect_uri") + assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) + + logoutUrl.RawQuery = "" + assert.Equal(t, EndSessionEndpoint, logoutUrl.String()) + }) + + t.Run("without id_token", func(t *testing.T) { + logout := newLogout(t) + idToken := "" + + raw := logout.SingleLogoutURL(idToken) + assert.NotEmpty(t, raw) + + logoutUrl, err := url.Parse(raw) + assert.NoError(t, err) + + query := logoutUrl.Query() + assert.Len(t, query, 1) + + assert.NotContains(t, query, "id_token_hint") + assert.Equal(t, idToken, query.Get("id_token_hint")) assert.Contains(t, query, "post_logout_redirect_uri") assert.Equal(t, LogoutCallbackURI, query.Get("post_logout_redirect_uri")) diff --git a/pkg/openid/cookies.go b/pkg/openid/cookies.go index aabcaeb..71167e2 100644 --- a/pkg/openid/cookies.go +++ b/pkg/openid/cookies.go @@ -6,8 +6,3 @@ type LoginCookie struct { CodeVerifier string `json:"code_verifier"` Referer string `json:"referer"` } - -type LogoutCookie struct { - State string `json:"state"` - RedirectTo string `json:"redirect_to"` -} diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index 7adccb9..e8e5661 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -1,23 +1,16 @@ package router import ( - "encoding/json" "errors" "fmt" "net/http" - "time" "github.com/go-redis/redis/v8" "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/openid" logentry "github.com/nais/wonderwall/pkg/router/middleware" ) -const ( - LogoutCookieLifetime = 5 * time.Minute -) - // Logout triggers self-initiated for the current user func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { var idToken string @@ -47,31 +40,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { logout, err := h.Client.Logout() if err != nil { h.InternalError(w, r, err) - } - - err = h.setLogoutCookie(w, logout.Cookie()) - if err != nil { - h.InternalError(w, r, fmt.Errorf("logout: setting logout cookie: %w", err)) return } - fields := map[string]interface{}{ - "redirect_to": logout.CanonicalRedirect(), - } - logger := logentry.LogEntryWithFields(r.Context(), fields) + logger := logentry.LogEntry(r.Context()) logger.Info().Msg("logout: redirecting to identity provider") http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect) } - -func (h *Handler) setLogoutCookie(w http.ResponseWriter, logoutCookie *openid.LogoutCookie) error { - logoutCookieJson, err := json.Marshal(logoutCookie) - if err != nil { - return fmt.Errorf("marshalling login cookie: %w", err) - } - - opts := h.CookieOptions.WithExpiresIn(LogoutCookieLifetime) - value := string(logoutCookieJson) - - return cookie.EncryptAndSet(w, cookie.Logout, value, opts, h.Crypter) -} diff --git a/pkg/router/handler_logout_callback.go b/pkg/router/handler_logout_callback.go index 714e63c..1aa7d02 100644 --- a/pkg/router/handler_logout_callback.go +++ b/pkg/router/handler_logout_callback.go @@ -1,55 +1,16 @@ package router import ( - "encoding/json" - "fmt" "net/http" - "github.com/nais/wonderwall/pkg/cookie" - "github.com/nais/wonderwall/pkg/openid" logentry "github.com/nais/wonderwall/pkg/router/middleware" ) // LogoutCallback handles the callback from the self-initiated logout for the current user func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) { - cookie.Clear(w, cookie.Logout, h.CookieOptions) + redirect := h.Client.LogoutCallback(r).PostLogoutRedirectURI() logger := logentry.LogEntry(r.Context()) - - logoutCookie, err := h.getLogoutCookie(r) - if err != nil { - logger.Warn().Msgf("logout/callback: getting cookie: %+v", err) - http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect) - return - } - - logoutCallback, err := h.Client.LogoutCallback(r, logoutCookie) - if err != nil { - h.InternalError(w, r, err) - return - } - - if err := logoutCallback.ValidateRequest(); err != nil { - logger.Warn().Msgf("logout/callback: %+v; falling back to ingress", err) - http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect) - return - } - - logger.Info().Msgf("logout/callback: redirecting to %s", logoutCookie.RedirectTo) - http.Redirect(w, r, logoutCookie.RedirectTo, http.StatusTemporaryRedirect) -} - -func (h *Handler) getLogoutCookie(r *http.Request) (*openid.LogoutCookie, error) { - logoutCookieJson, err := cookie.GetDecrypted(r, cookie.Logout, h.Crypter) - if err != nil { - return nil, err - } - - var logoutCookie openid.LogoutCookie - err = json.Unmarshal([]byte(logoutCookieJson), &logoutCookie) - if err != nil { - return nil, fmt.Errorf("unmarshalling: %w", err) - } - - return &logoutCookie, nil + logger.Info().Msgf("logout/callback: redirecting to %s", redirect) + http.Redirect(w, r, redirect, http.StatusTemporaryRedirect) } diff --git a/pkg/router/middleware/logentry.go b/pkg/router/middleware/logentry.go index cd8ced3..25cbcf8 100644 --- a/pkg/router/middleware/logentry.go +++ b/pkg/router/middleware/logentry.go @@ -175,8 +175,7 @@ func isRelevantCookie(name string) bool { switch name { case cookie.Session, cookie.Login, - cookie.LoginLegacy, - cookie.Logout: + cookie.LoginLegacy: return true } diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index 61043c0..3445090 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -133,10 +133,8 @@ func TestHandler_Callback_and_Logout(t *testing.T) { cookies = rpClient.Jar.Cookies(logoutURL) sessionCookie = getCookieFromJar(cookie.Session, cookies) - logoutCookie := getCookieFromJar(cookie.Logout, cookies) assert.Nil(t, sessionCookie) - assert.NotNil(t, logoutCookie) // Get endsession endpoint after local logout location = resp.Header.Get("location") @@ -147,12 +145,10 @@ func TestHandler_Callback_and_Logout(t *testing.T) { assert.NoError(t, err) endsessionParams := endsessionURL.Query() - expectedState := endsessionParams["state"] assert.Equal(t, idpserverURL.Host, endsessionURL.Host) assert.Equal(t, "/endsession", endsessionURL.Path) assert.Equal(t, endsessionParams["post_logout_redirect_uri"], []string{idp.OpenIDConfig.Client().GetLogoutCallbackURI()}) assert.NotEmpty(t, endsessionParams["id_token_hint"]) - assert.NotEmpty(t, expectedState) // Follow redirect to endsession endpoint at identity provider resp, err = rpClient.Get(endsessionURL.String()) @@ -165,12 +161,8 @@ func TestHandler_Callback_and_Logout(t *testing.T) { logoutCallbackURI, err := url.Parse(location) assert.NoError(t, err) assert.Contains(t, logoutCallbackURI.String(), idp.OpenIDConfig.Client().GetLogoutCallbackURI()) - logoutCallbackParams := endsessionURL.Query() - actualState := logoutCallbackParams["state"] assert.Equal(t, "/oauth2/logout/callback", logoutCallbackURI.Path) - assert.NotEmpty(t, actualState) - assert.Equal(t, expectedState, actualState) // Follow redirect back to logout callback resp, err = rpClient.Get(logoutCallbackURI.String()) @@ -185,10 +177,8 @@ func TestHandler_Callback_and_Logout(t *testing.T) { cookies = rpClient.Jar.Cookies(logoutCallbackURI) sessionCookie = getCookieFromJar(cookie.Session, cookies) - logoutCookie = getCookieFromJar(cookie.Logout, cookies) assert.Nil(t, sessionCookie) - assert.Nil(t, logoutCookie) } func TestHandler_FrontChannelLogout(t *testing.T) {