diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index 5a34dd7..c8cb903 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -20,9 +20,9 @@ type Client interface { oAuth2Config() *oauth2.Config Login(r *http.Request) (Login, error) - LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) LoginCallback + LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error) Logout() (Logout, error) - LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback + LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) LogoutFrontchannel(r *http.Request) LogoutFrontchannel AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error) @@ -70,8 +70,13 @@ func (c client) Login(r *http.Request) (Login, error) { return login, nil } -func (c client) LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) LoginCallback { - return NewLoginCallback(c, r, p, cookie) +func (c client) LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error) { + loginCallback, err := NewLoginCallback(c, r, p, cookie) + if err != nil { + return nil, fmt.Errorf("callback: %w", err) + } + + return loginCallback, nil } func (c client) Logout() (Logout, error) { @@ -83,8 +88,13 @@ func (c client) Logout() (Logout, error) { return logout, nil } -func (c client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback { - return NewLogoutCallback(r, cookie) +func (c client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) { + logoutCallback, err := NewLogoutCallback(r, cookie) + if err != nil { + return nil, fmt.Errorf("logout/callback: %w", err) + } + + return logoutCallback, nil } func (c client) LogoutFrontchannel(r *http.Request) LogoutFrontchannel { diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index 233d9e7..8e59350 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -29,14 +29,18 @@ type loginCallback struct { requestParams url.Values } -func NewLoginCallback(c Client, r *http.Request, p provider.Provider, cookie *openid.LoginCookie) LoginCallback { +func NewLoginCallback(c Client, r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error) { + if cookie == nil { + return nil, fmt.Errorf("cookie is nil") + } + return &loginCallback{ client: c, cookie: cookie, provider: p, request: r, requestParams: r.URL.Query(), - } + }, nil } func (in loginCallback) IdentityProviderError() error { diff --git a/pkg/openid/client/logout_callback.go b/pkg/openid/client/logout_callback.go index f25b6f5..61bcf69 100644 --- a/pkg/openid/client/logout_callback.go +++ b/pkg/openid/client/logout_callback.go @@ -17,11 +17,15 @@ type logoutCallback struct { requestParams url.Values } -func NewLogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback { +func NewLogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) { + if cookie == nil { + return nil, fmt.Errorf("cookie is nil") + } + return &logoutCallback{ requestParams: r.URL.Query(), cookie: cookie, - } + }, nil } func (in logoutCallback) ValidateRequest() error { diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index 12d072a..6cfdb3a 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -36,7 +36,11 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - loginCallback := h.Client.LoginCallback(r, h.Provider, loginCookie) + loginCallback, err := h.Client.LoginCallback(r, h.Provider, loginCookie) + if err != nil { + h.InternalError(w, r, err) + return + } if err := loginCallback.IdentityProviderError(); err != nil { h.InternalError(w, r, fmt.Errorf("callback: %w", err)) diff --git a/pkg/router/handler_logout_callback.go b/pkg/router/handler_logout_callback.go index 2e0a810..714e63c 100644 --- a/pkg/router/handler_logout_callback.go +++ b/pkg/router/handler_logout_callback.go @@ -23,7 +23,12 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) { return } - logoutCallback := h.Client.LogoutCallback(r, logoutCookie) + logoutCallback, err := h.Client.LogoutCallback(r, logoutCookie) + if err != nil { + h.InternalError(w, r, err) + return + } + if err := logoutCallback.ValidateRequest(); err != nil { logger.Warn().Msgf("logout/callback: %+v; falling back to ingress", err) http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect)