From b770f2217434b37557b33e490b59fee8bf43af76 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Mon, 11 Jul 2022 13:37:40 +0200 Subject: [PATCH] refactor(handler/logoutcallback): extract to openid client --- pkg/openid/client/client.go | 7 ++-- pkg/openid/client/logout_callback.go | 60 ++++++++++++++++++++++++--- pkg/router/handler_logout_callback.go | 15 ++----- 3 files changed, 61 insertions(+), 21 deletions(-) diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index 4fb4b30..5a34dd7 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -22,7 +22,7 @@ type Client interface { Login(r *http.Request) (Login, error) LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) LoginCallback Logout() (Logout, error) - LogoutCallback(r *http.Request) error + LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback LogoutFrontchannel(r *http.Request) LogoutFrontchannel AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error) @@ -83,9 +83,8 @@ func (c client) Logout() (Logout, error) { return logout, nil } -func (c client) LogoutCallback(r *http.Request) error { - //TODO implement me - panic("implement me") +func (c client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback { + return NewLogoutCallback(r, cookie) } func (c client) LogoutFrontchannel(r *http.Request) LogoutFrontchannel { diff --git a/pkg/openid/client/logout_callback.go b/pkg/openid/client/logout_callback.go index 4c1304a..f25b6f5 100644 --- a/pkg/openid/client/logout_callback.go +++ b/pkg/openid/client/logout_callback.go @@ -1,10 +1,60 @@ package client -type LogoutCallback struct { - Client +import ( + "fmt" + "net/http" + "net/url" + + "github.com/nais/wonderwall/pkg/openid" +) + +type LogoutCallback interface { + ValidateRequest() error } -func (in LogoutCallback) ValidateRequest() (bool, error) { - // TODO - panic("not implemented") +type logoutCallback struct { + cookie *openid.LogoutCookie + requestParams url.Values +} + +func NewLogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback { + return &logoutCallback{ + requestParams: r.URL.Query(), + cookie: cookie, + } +} + +func (in logoutCallback) ValidateRequest() error { + if err := in.emptyRedirectError(); err != nil { + return err + } + + if err := in.stateMismatchError(); err != nil { + return err + } + + return nil +} + +func (in logoutCallback) emptyRedirectError() error { + if len(in.cookie.RedirectTo) == 0 { + return fmt.Errorf("empty redirect") + } + + return nil +} + +func (in logoutCallback) stateMismatchError() error { + expectedState := in.cookie.State + actualState := in.requestParams.Get("state") + + if len(actualState) <= 0 { + return fmt.Errorf("missing state parameter in request (possible csrf)") + } + + if expectedState != actualState { + return fmt.Errorf("state parameter mismatch (possible csrf): expected %s, got %s", expectedState, actualState) + } + + return nil } diff --git a/pkg/router/handler_logout_callback.go b/pkg/router/handler_logout_callback.go index 8e18e47..2e0a810 100644 --- a/pkg/router/handler_logout_callback.go +++ b/pkg/router/handler_logout_callback.go @@ -23,18 +23,9 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) { return } - params := r.URL.Query() - expectedState := logoutCookie.State - actualState := params.Get("state") - - if expectedState != actualState { - logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s; falling back to ingress", expectedState, actualState) - http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect) - return - } - - if len(logoutCookie.RedirectTo) == 0 { - logger.Warn().Msgf("logout/callback: empty redirect; falling back to ingress") + logoutCallback := h.Client.LogoutCallback(r, logoutCookie) + if err := logoutCallback.ValidateRequest(); err != nil { + logger.Warn().Msgf("logout/callback: %+v; falling back to ingress", err) http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect) return }