refactor(openid/client): ensure callback cookies are not nil

This commit is contained in:
Trong Huu Nguyen
2022-07-11 14:30:04 +02:00
parent 48160e7986
commit b937c64dd6
5 changed files with 39 additions and 12 deletions

View File

@@ -20,9 +20,9 @@ type Client interface {
oAuth2Config() *oauth2.Config
Login(r *http.Request) (Login, error)
LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) LoginCallback
LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error)
Logout() (Logout, error)
LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback
LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error)
LogoutFrontchannel(r *http.Request) LogoutFrontchannel
AuthCodeGrant(ctx context.Context, code string, opts []oauth2.AuthCodeOption) (*oauth2.Token, error)
@@ -70,8 +70,13 @@ func (c client) Login(r *http.Request) (Login, error) {
return login, nil
}
func (c client) LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) LoginCallback {
return NewLoginCallback(c, r, p, cookie)
func (c client) LoginCallback(r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error) {
loginCallback, err := NewLoginCallback(c, r, p, cookie)
if err != nil {
return nil, fmt.Errorf("callback: %w", err)
}
return loginCallback, nil
}
func (c client) Logout() (Logout, error) {
@@ -83,8 +88,13 @@ func (c client) Logout() (Logout, error) {
return logout, nil
}
func (c client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback {
return NewLogoutCallback(r, cookie)
func (c client) LogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) {
logoutCallback, err := NewLogoutCallback(r, cookie)
if err != nil {
return nil, fmt.Errorf("logout/callback: %w", err)
}
return logoutCallback, nil
}
func (c client) LogoutFrontchannel(r *http.Request) LogoutFrontchannel {

View File

@@ -29,14 +29,18 @@ type loginCallback struct {
requestParams url.Values
}
func NewLoginCallback(c Client, r *http.Request, p provider.Provider, cookie *openid.LoginCookie) LoginCallback {
func NewLoginCallback(c Client, r *http.Request, p provider.Provider, cookie *openid.LoginCookie) (LoginCallback, error) {
if cookie == nil {
return nil, fmt.Errorf("cookie is nil")
}
return &loginCallback{
client: c,
cookie: cookie,
provider: p,
request: r,
requestParams: r.URL.Query(),
}
}, nil
}
func (in loginCallback) IdentityProviderError() error {

View File

@@ -17,11 +17,15 @@ type logoutCallback struct {
requestParams url.Values
}
func NewLogoutCallback(r *http.Request, cookie *openid.LogoutCookie) LogoutCallback {
func NewLogoutCallback(r *http.Request, cookie *openid.LogoutCookie) (LogoutCallback, error) {
if cookie == nil {
return nil, fmt.Errorf("cookie is nil")
}
return &logoutCallback{
requestParams: r.URL.Query(),
cookie: cookie,
}
}, nil
}
func (in logoutCallback) ValidateRequest() error {

View File

@@ -36,7 +36,11 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
loginCallback := h.Client.LoginCallback(r, h.Provider, loginCookie)
loginCallback, err := h.Client.LoginCallback(r, h.Provider, loginCookie)
if err != nil {
h.InternalError(w, r, err)
return
}
if err := loginCallback.IdentityProviderError(); err != nil {
h.InternalError(w, r, fmt.Errorf("callback: %w", err))

View File

@@ -23,7 +23,12 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) {
return
}
logoutCallback := h.Client.LogoutCallback(r, logoutCookie)
logoutCallback, err := h.Client.LogoutCallback(r, logoutCookie)
if err != nil {
h.InternalError(w, r, err)
return
}
if err := logoutCallback.ValidateRequest(); err != nil {
logger.Warn().Msgf("logout/callback: %+v; falling back to ingress", err)
http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect)