mirror of
https://github.com/nais/wonderwall.git
synced 2026-02-14 17:49:54 +00:00
refactor(handler): inline error handler, remove unnecessary getters
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
package error
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/handler/templates"
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
@@ -24,45 +23,30 @@ const (
|
||||
MaxAutoRetryAttempts = 3
|
||||
)
|
||||
|
||||
type Source interface {
|
||||
GetCookieOptions(r *http.Request) cookie.Options
|
||||
GetCrypter() crypto.Crypter
|
||||
GetPath(r *http.Request) string
|
||||
GetRedirect() urlpkg.Redirect
|
||||
}
|
||||
|
||||
type Page struct {
|
||||
CorrelationID string
|
||||
RetryURI string
|
||||
}
|
||||
|
||||
type Handler struct {
|
||||
Source
|
||||
func (s *Standalone) InternalError(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
s.respondError(w, r, http.StatusInternalServerError, cause, log.ErrorLevel)
|
||||
}
|
||||
|
||||
func New(src Source) Handler {
|
||||
return Handler{src}
|
||||
func (s *Standalone) BadRequest(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
s.respondError(w, r, http.StatusBadRequest, cause, log.ErrorLevel)
|
||||
}
|
||||
|
||||
func (h Handler) InternalError(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
h.respondError(w, r, http.StatusInternalServerError, cause, log.ErrorLevel)
|
||||
}
|
||||
|
||||
func (h Handler) BadRequest(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
h.respondError(w, r, http.StatusBadRequest, cause, log.ErrorLevel)
|
||||
}
|
||||
|
||||
func (h Handler) Unauthorized(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
h.respondError(w, r, http.StatusUnauthorized, cause, log.WarnLevel)
|
||||
func (s *Standalone) Unauthorized(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
s.respondError(w, r, http.StatusUnauthorized, cause, log.WarnLevel)
|
||||
}
|
||||
|
||||
// Retry returns a URI that should retry the desired route that failed.
|
||||
// It only handles the routes exposed by Wonderwall, i.e. `/oauth2/*`. As these routes
|
||||
// are related to the authentication flow, we default to redirecting back to the handled
|
||||
// `/oauth2/login` endpoint unless the original request attempted to reach the logout-flow.
|
||||
func (h Handler) Retry(r *http.Request, loginCookie *openid.LoginCookie) string {
|
||||
func (s *Standalone) Retry(r *http.Request, loginCookie *openid.LoginCookie) string {
|
||||
requestPath := r.URL.Path
|
||||
ingressPath := h.GetPath(r)
|
||||
ingressPath := s.GetPath(r)
|
||||
|
||||
for _, path := range []string{paths.Logout, paths.LogoutLocal, paths.LogoutFrontChannel} {
|
||||
if strings.HasSuffix(requestPath, paths.OAuth2+path) {
|
||||
@@ -70,28 +54,28 @@ func (h Handler) Retry(r *http.Request, loginCookie *openid.LoginCookie) string
|
||||
}
|
||||
}
|
||||
|
||||
redirect := h.GetRedirect().Canonical(r)
|
||||
redirect := s.Redirect.Canonical(r)
|
||||
if loginCookie != nil && len(loginCookie.Referer) > 0 {
|
||||
redirect = h.GetRedirect().Clean(r, loginCookie.Referer)
|
||||
redirect = s.Redirect.Clean(r, loginCookie.Referer)
|
||||
}
|
||||
|
||||
return urlpkg.LoginRelative(ingressPath, redirect)
|
||||
}
|
||||
|
||||
func (h Handler) respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause error, level log.Level) {
|
||||
func (s *Standalone) respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause error, level log.Level) {
|
||||
logger := mw.LogEntryFrom(r)
|
||||
msg := "error in route: %+v"
|
||||
|
||||
incrementRetryAttempt(w, r, h.GetCookieOptions(r))
|
||||
incrementRetryAttempt(w, r, s.GetCookieOptions(r))
|
||||
|
||||
attempts, ok := getRetryAttempts(r)
|
||||
if !ok || ok && attempts < MaxAutoRetryAttempts {
|
||||
loginCookie, err := openid.GetLoginCookie(r, h.GetCrypter())
|
||||
loginCookie, err := openid.GetLoginCookie(r, s.Crypter)
|
||||
if err != nil {
|
||||
loginCookie = nil
|
||||
}
|
||||
|
||||
retryUri := h.Retry(r, loginCookie)
|
||||
retryUri := s.Retry(r, loginCookie)
|
||||
logger.Warnf(msg, cause)
|
||||
|
||||
logger.Infof("errorhandler: auto-retry (attempt %d/%d) redirecting to %q...", attempts+1, MaxAutoRetryAttempts, retryUri)
|
||||
@@ -107,20 +91,20 @@ func (h Handler) respondError(w http.ResponseWriter, r *http.Request, statusCode
|
||||
}
|
||||
|
||||
logger.Info("errorhandler: maximum retry attempts exceeded; executing error template...")
|
||||
h.defaultErrorResponse(w, r, statusCode)
|
||||
s.defaultErrorResponse(w, r, statusCode)
|
||||
}
|
||||
|
||||
func (h Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, statusCode int) {
|
||||
func (s *Standalone) defaultErrorResponse(w http.ResponseWriter, r *http.Request, statusCode int) {
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
loginCookie, err := openid.GetLoginCookie(r, h.GetCrypter())
|
||||
loginCookie, err := openid.GetLoginCookie(r, s.Crypter)
|
||||
if err != nil {
|
||||
loginCookie = nil
|
||||
}
|
||||
|
||||
errorPage := Page{
|
||||
CorrelationID: middleware.GetReqID(r.Context()),
|
||||
RetryURI: h.Retry(r, loginCookie),
|
||||
RetryURI: s.Retry(r, loginCookie),
|
||||
}
|
||||
err = templates.ErrorTemplate.Execute(w, errorPage)
|
||||
if err != nil {
|
||||
@@ -1,4 +1,4 @@
|
||||
package error_test
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
|
||||
errorhandler "github.com/nais/wonderwall/pkg/handler"
|
||||
"github.com/nais/wonderwall/pkg/ingress"
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/mock"
|
||||
@@ -21,7 +21,7 @@ func TestHandler_Error(t *testing.T) {
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
rpHandler := idp.RelyingPartyHandler.GetErrorHandler()
|
||||
rpHandler := idp.RelyingPartyHandler
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
@@ -193,14 +193,12 @@ func TestHandler_Retry(t *testing.T) {
|
||||
idp := mock.NewIdentityProvider(cfg)
|
||||
defer idp.Close()
|
||||
|
||||
handler := idp.RelyingPartyHandler.GetErrorHandler()
|
||||
|
||||
ing, err := ingress.ParseIngress(test.ingress)
|
||||
assert.NoError(t, err)
|
||||
|
||||
test.request = mw.RequestWithPath(test.request, ing.Path())
|
||||
|
||||
retryURI := handler.Retry(test.request, test.loginCookie)
|
||||
retryURI := idp.RelyingPartyHandler.Retry(test.request, test.loginCookie)
|
||||
assert.Equal(t, test.want, retryURI)
|
||||
})
|
||||
}
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
"github.com/nais/wonderwall/pkg/handler/autologin"
|
||||
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
|
||||
"github.com/nais/wonderwall/pkg/ingress"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
@@ -37,7 +36,6 @@ type Standalone struct {
|
||||
Config *config.Config
|
||||
CookieOptions cookie.Options
|
||||
Crypter crypto.Crypter
|
||||
ErrorHandler errorhandler.Handler
|
||||
Ingresses *ingress.Ingresses
|
||||
OpenidConfig openidconfig.Config
|
||||
Redirect url.Redirect
|
||||
@@ -114,34 +112,22 @@ func (s *Standalone) GetCookieOptions(r *http.Request) cookie.Options {
|
||||
return s.CookieOptions.WithPath(path)
|
||||
}
|
||||
|
||||
func (s *Standalone) GetCrypter() crypto.Crypter {
|
||||
return s.Crypter
|
||||
}
|
||||
|
||||
func (s *Standalone) GetErrorHandler() errorhandler.Handler {
|
||||
return errorhandler.New(s)
|
||||
}
|
||||
|
||||
func (s *Standalone) GetIngresses() *ingress.Ingresses {
|
||||
return s.Ingresses
|
||||
}
|
||||
|
||||
func (s *Standalone) GetPath(r *http.Request) string {
|
||||
return GetPath(r, s)
|
||||
}
|
||||
|
||||
func (s *Standalone) GetRedirect() url.Redirect {
|
||||
return s.Redirect
|
||||
return GetPath(r, s.GetIngresses())
|
||||
}
|
||||
|
||||
func (s *Standalone) Login(w http.ResponseWriter, r *http.Request) {
|
||||
canonicalRedirect := s.GetRedirect().Canonical(r)
|
||||
canonicalRedirect := s.Redirect.Canonical(r)
|
||||
login, err := s.Client.Login(r)
|
||||
if err != nil {
|
||||
if errors.Is(err, openidclient.ErrInvalidSecurityLevel) || errors.Is(err, openidclient.ErrInvalidLocale) {
|
||||
s.GetErrorHandler().BadRequest(w, r, err)
|
||||
s.BadRequest(w, r, err)
|
||||
} else {
|
||||
s.GetErrorHandler().InternalError(w, r, err)
|
||||
s.InternalError(w, r, err)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -150,9 +136,9 @@ func (s *Standalone) Login(w http.ResponseWriter, r *http.Request) {
|
||||
opts := s.GetCookieOptions(r).
|
||||
WithExpiresIn(1 * time.Hour).
|
||||
WithSameSite(http.SameSiteNoneMode)
|
||||
err = login.SetCookie(w, opts, s.GetCrypter(), canonicalRedirect)
|
||||
err = login.SetCookie(w, opts, s.Crypter, canonicalRedirect)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
|
||||
s.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -170,29 +156,29 @@ func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) {
|
||||
cookie.Clear(w, cookie.Login, opts.WithSameSite(http.SameSiteNoneMode))
|
||||
cookie.Clear(w, cookie.LoginLegacy, opts.WithSameSite(http.SameSiteDefaultMode))
|
||||
|
||||
loginCookie, err := openid.GetLoginCookie(r, s.GetCrypter())
|
||||
loginCookie, err := openid.GetLoginCookie(r, s.Crypter)
|
||||
if err != nil {
|
||||
msg := "callback: fetching login cookie"
|
||||
if errors.Is(err, http.ErrNoCookie) {
|
||||
msg += ": fallback cookie not found (user might have blocked all cookies, or the callback route was accessed before the login route)"
|
||||
}
|
||||
s.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("%s: %w", msg, err))
|
||||
s.Unauthorized(w, r, fmt.Errorf("%s: %w", msg, err))
|
||||
return
|
||||
}
|
||||
|
||||
loginCallback, err := s.Client.LoginCallback(r, loginCookie)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, err)
|
||||
s.InternalError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := loginCallback.IdentityProviderError(); err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: %w", err))
|
||||
s.InternalError(w, r, fmt.Errorf("callback: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
if err := loginCallback.StateMismatchError(); err != nil {
|
||||
s.GetErrorHandler().Unauthorized(w, r, fmt.Errorf("callback: %w", err))
|
||||
s.Unauthorized(w, r, fmt.Errorf("callback: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -202,7 +188,7 @@ func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return retry.RetryableError(err)
|
||||
})
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err))
|
||||
s.InternalError(w, r, fmt.Errorf("callback: redeeming tokens: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -210,17 +196,17 @@ func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
sess, err := s.SessionManager.Create(r, tokens, sessionLifetime)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
|
||||
s.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
err = sess.SetCookie(w, opts.WithExpiresIn(sessionLifetime), s.GetCrypter())
|
||||
err = sess.SetCookie(w, opts.WithExpiresIn(sessionLifetime), s.Crypter)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err))
|
||||
s.InternalError(w, r, fmt.Errorf("callback: setting session cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
redirect := s.GetRedirect().Clean(r, loginCookie.Referer)
|
||||
redirect := s.Redirect.Clean(r, loginCookie.Referer)
|
||||
|
||||
fields := log.Fields{
|
||||
"redirect_to": redirect,
|
||||
@@ -234,28 +220,18 @@ func (s *Standalone) LoginCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Standalone) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
opts := logoutOptions{
|
||||
GlobalLogout: true,
|
||||
}
|
||||
s.logout(w, r, opts)
|
||||
s.logout(w, r, true)
|
||||
}
|
||||
|
||||
func (s *Standalone) LogoutLocal(w http.ResponseWriter, r *http.Request) {
|
||||
opts := logoutOptions{
|
||||
GlobalLogout: false,
|
||||
}
|
||||
s.logout(w, r, opts)
|
||||
s.logout(w, r, false)
|
||||
}
|
||||
|
||||
type logoutOptions struct {
|
||||
GlobalLogout bool
|
||||
}
|
||||
|
||||
func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, opts logoutOptions) {
|
||||
func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, globalLogout bool) {
|
||||
logger := mw.LogEntryFrom(r)
|
||||
logout, err := s.Client.Logout(r)
|
||||
if err != nil {
|
||||
s.GetErrorHandler().InternalError(w, r, err)
|
||||
s.InternalError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -267,7 +243,7 @@ func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, opts logoutO
|
||||
|
||||
err = s.SessionManager.Delete(r.Context(), sess)
|
||||
if err != nil && !errors.Is(err, session.ErrNotFound) {
|
||||
s.GetErrorHandler().InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
|
||||
s.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -277,7 +253,7 @@ func (s *Standalone) logout(w http.ResponseWriter, r *http.Request, opts logoutO
|
||||
|
||||
cookie.Clear(w, cookie.Session, s.GetCookieOptions(r))
|
||||
|
||||
if opts.GlobalLogout {
|
||||
if globalLogout {
|
||||
logger.Debug("logout: redirecting to identity provider for global/single-logout")
|
||||
metrics.ObserveLogout(metrics.LogoutOperationSelfInitiated)
|
||||
http.Redirect(w, r, logout.SingleLogoutURL(idToken), http.StatusTemporaryRedirect)
|
||||
|
||||
@@ -82,6 +82,32 @@ func NewSSOProxy(cfg *config.Config, crypter crypto.Crypter) (*SSOProxy, error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetAccessToken(r *http.Request) (string, error) {
|
||||
sess, err := s.SessionReader.Get(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sess.AccessToken()
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetAutoLogin() *autologin.AutoLogin {
|
||||
return s.AutoLogin
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetIngresses() *ingress.Ingresses {
|
||||
return s.Ingresses
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetPath(r *http.Request) string {
|
||||
return GetPath(r, s.GetIngresses())
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetSSOServerURL() *urllib.URL {
|
||||
u := *s.SSOServerURL
|
||||
return &u
|
||||
}
|
||||
|
||||
func (s *SSOProxy) Login(w http.ResponseWriter, r *http.Request) {
|
||||
logger := logentry.LogEntryFrom(r)
|
||||
|
||||
@@ -99,7 +125,7 @@ func (s *SSOProxy) Login(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
target.RawQuery = reqQuery.Encode()
|
||||
|
||||
canonicalRedirect := s.GetRedirect().Canonical(r)
|
||||
canonicalRedirect := s.Redirect.Canonical(r)
|
||||
ssoServerLoginURL := url.Login(target, canonicalRedirect)
|
||||
|
||||
logger.WithFields(log.Fields{
|
||||
@@ -144,33 +170,3 @@ func (s *SSOProxy) SessionRefresh(w http.ResponseWriter, r *http.Request) {
|
||||
func (s *SSOProxy) Wildcard(w http.ResponseWriter, r *http.Request) {
|
||||
s.UpstreamProxy.Handler(s, w, r)
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetAccessToken(r *http.Request) (string, error) {
|
||||
sess, err := s.SessionReader.Get(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return sess.AccessToken()
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetAutoLogin() *autologin.AutoLogin {
|
||||
return s.AutoLogin
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetIngresses() *ingress.Ingresses {
|
||||
return s.Ingresses
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetPath(r *http.Request) string {
|
||||
return GetPath(r, s)
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetRedirect() url.Redirect {
|
||||
return s.Redirect
|
||||
}
|
||||
|
||||
func (s *SSOProxy) GetSSOServerURL() *urllib.URL {
|
||||
u := *s.SSOServerURL
|
||||
return &u
|
||||
}
|
||||
|
||||
@@ -7,16 +7,12 @@ import (
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
)
|
||||
|
||||
type PathSource interface {
|
||||
GetIngresses() *ingress.Ingresses
|
||||
}
|
||||
|
||||
// GetPath returns the matching context path from the list of registered ingresses.
|
||||
// If none match, an empty string is returned.
|
||||
func GetPath(r *http.Request, src PathSource) string {
|
||||
func GetPath(r *http.Request, ingresses *ingress.Ingresses) string {
|
||||
path, ok := mw.PathFrom(r.Context())
|
||||
if !ok {
|
||||
path = src.GetIngresses().MatchingPath(r)
|
||||
path = ingresses.MatchingPath(r)
|
||||
}
|
||||
|
||||
return path
|
||||
|
||||
Reference in New Issue
Block a user