diff --git a/pkg/errorhandler/errorhandler.go b/pkg/errorhandler/errorhandler.go index a9fbcb5..48bb545 100644 --- a/pkg/errorhandler/errorhandler.go +++ b/pkg/errorhandler/errorhandler.go @@ -2,6 +2,7 @@ package errorhandler import ( "errors" + "github.com/nais/wonderwall/pkg/middleware/correlationid" log "github.com/sirupsen/logrus" "net/http" ) @@ -11,19 +12,28 @@ var ( InvalidLocaleError = errors.New("InvalidLocale") ) -func respondError(w http.ResponseWriter, statusCode int, cause error) { - log.Error(cause) +func respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause error) { + id, ok := correlationid.GetFromContext(r.Context()) + if !ok { + log.Warnf("no correlation id in context") + } + + logFields := log.Fields{ + "correlation_id": id, + } + + log.WithFields(logFields).Error(cause) w.WriteHeader(statusCode) } -func InternalError(w http.ResponseWriter, cause error) { - respondError(w, http.StatusInternalServerError, cause) +func InternalError(w http.ResponseWriter, r *http.Request, cause error) { + respondError(w, r, http.StatusInternalServerError, cause) } -func BadRequest(w http.ResponseWriter, cause error) { - respondError(w, http.StatusBadRequest, cause) +func BadRequest(w http.ResponseWriter, r *http.Request, cause error) { + respondError(w, r, http.StatusBadRequest, cause) } -func Unauthorized(w http.ResponseWriter, cause error) { - respondError(w, http.StatusUnauthorized, cause) +func Unauthorized(w http.ResponseWriter, r *http.Request, cause error) { + respondError(w, r, http.StatusUnauthorized, cause) } diff --git a/pkg/middleware/correlationid/correlationid.go b/pkg/middleware/correlationid/correlationid.go new file mode 100644 index 0000000..c8a0622 --- /dev/null +++ b/pkg/middleware/correlationid/correlationid.go @@ -0,0 +1,24 @@ +package correlationid + +import ( + "context" + "github.com/google/uuid" + "net/http" +) + +// contextKey is the type of contextKeys used for correlation IDs. +type contextKey struct{} + +func GetFromContext(ctx context.Context) (string, bool) { + id, ok := ctx.Value(contextKey{}).(string) + return id, ok +} + +func Handler(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + ctx = context.WithValue(ctx, contextKey{}, uuid.New().String()) + next.ServeHTTP(w, r.WithContext(ctx)) + } + return http.HandlerFunc(fn) +} diff --git a/pkg/middleware/prometheus.go b/pkg/middleware/prometheus/prometheus.go similarity index 96% rename from pkg/middleware/prometheus.go rename to pkg/middleware/prometheus/prometheus.go index acf9257..bb2aa3b 100644 --- a/pkg/middleware/prometheus.go +++ b/pkg/middleware/prometheus/prometheus.go @@ -1,7 +1,7 @@ // This code was originally written by Rene Zbinden and modified by Vladimir Konovalov. // Copied from https://github.com/766b/chi-prometheus and further adapted. -package middleware +package prometheus import ( "net/http" @@ -31,7 +31,7 @@ type Middleware struct { } // NewMiddleware returns a new prometheus Middleware handler. -func PrometheusMiddleware(name string, buckets ...float64) *Middleware { +func NewMiddleware(name string, buckets ...float64) *Middleware { var m Middleware m.reqs = prometheus.NewCounterVec( prometheus.CounterOpts{ diff --git a/pkg/router/router.go b/pkg/router/router.go index 6c6c624..0ee4885 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "github.com/nais/wonderwall/pkg/middleware/correlationid" + "github.com/nais/wonderwall/pkg/middleware/prometheus" "io" "net/http" "net/url" @@ -14,7 +16,6 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cryptutil" "github.com/nais/wonderwall/pkg/errorhandler" - "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/session" "github.com/nais/wonderwall/pkg/token" @@ -75,7 +76,7 @@ func (h *Handler) WithSecureCookie(enabled bool) *Handler { func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { params, err := auth.GenerateLoginParameters() if err != nil { - errorhandler.InternalError(w, fmt.Errorf("login: generating login parameters: %w", err)) + errorhandler.InternalError(w, r, fmt.Errorf("login: generating login parameters: %w", err)) return } @@ -84,9 +85,9 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { cause := fmt.Errorf("login: creating login URL: %w", err) if errors.Is(err, errorhandler.InvalidSecurityLevelError) || errors.Is(err, errorhandler.InvalidLocaleError) { - errorhandler.BadRequest(w, cause) + errorhandler.BadRequest(w, r, cause) } else { - errorhandler.InternalError(w, cause) + errorhandler.InternalError(w, r, cause) } return @@ -99,7 +100,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { Referer: CanonicalRedirectURL(r), }) if err != nil { - errorhandler.InternalError(w, fmt.Errorf("login: setting cookie: %w", err)) + errorhandler.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err)) return } @@ -109,7 +110,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { loginCookie, err := h.getLoginCookie(w, r) if err != nil { - errorhandler.Unauthorized(w, fmt.Errorf("callback: fetching login cookie: %w", err)) + errorhandler.Unauthorized(w, r, fmt.Errorf("callback: fetching login cookie: %w", err)) return } @@ -117,18 +118,18 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { if params.Get("error") != "" { oauthError := params.Get("error") oauthErrorDescription := params.Get("error_description") - errorhandler.Unauthorized(w, fmt.Errorf("callback: error from identity provider: %s: %s", oauthError, oauthErrorDescription)) + errorhandler.Unauthorized(w, r, fmt.Errorf("callback: error from identity provider: %s: %s", oauthError, oauthErrorDescription)) return } if params.Get("state") != loginCookie.State { - errorhandler.Unauthorized(w, fmt.Errorf("callback: state parameter mismatch")) + errorhandler.Unauthorized(w, r, fmt.Errorf("callback: state parameter mismatch")) return } assertion, err := h.Config.SignedJWTProfileAssertion(time.Second * 100) if err != nil { - errorhandler.InternalError(w, fmt.Errorf("callback: creating client assertion: %w", err)) + errorhandler.InternalError(w, r, fmt.Errorf("callback: creating client assertion: %w", err)) return } @@ -140,13 +141,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { tokens, err := h.OauthConfig.Exchange(r.Context(), params.Get("code"), opts...) if err != nil { - errorhandler.Unauthorized(w, fmt.Errorf("callback: exchanging code: %w", err)) + errorhandler.Unauthorized(w, r, fmt.Errorf("callback: exchanging code: %w", err)) return } idToken, err := token.ParseIDToken(h.jwkSet, tokens) if err != nil { - errorhandler.Unauthorized(w, fmt.Errorf("callback: parsing id_token: %w", err)) + errorhandler.Unauthorized(w, r, fmt.Errorf("callback: parsing id_token: %w", err)) return } @@ -164,19 +165,19 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { err = idToken.Validate(validateOpts...) if err != nil { - errorhandler.Unauthorized(w, fmt.Errorf("callback: validating id_token: %w", err)) + errorhandler.Unauthorized(w, r, fmt.Errorf("callback: validating id_token: %w", err)) return } externalSessionID, ok := idToken.GetSID() if !ok { - errorhandler.Unauthorized(w, fmt.Errorf("callback: missing required 'sid' claim in id_token")) + errorhandler.Unauthorized(w, r, fmt.Errorf("callback: missing required 'sid' claim in id_token")) return } err = h.createSession(w, r, externalSessionID, tokens, idToken) if err != nil { - errorhandler.InternalError(w, fmt.Errorf("callback: creating session: %w", err)) + errorhandler.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err)) return } @@ -241,7 +242,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint) if err != nil { - errorhandler.InternalError(w, fmt.Errorf("logout: parsing end session endpoint: %w", err)) + errorhandler.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err)) return } @@ -252,7 +253,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { idToken = sess.IDToken err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID)) if err != nil { - errorhandler.InternalError(w, fmt.Errorf("logout: destroying session: %w", err)) + errorhandler.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) return } } @@ -277,7 +278,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { sid := params.Get("sid") if len(sid) == 0 { - errorhandler.BadRequest(w, fmt.Errorf("front-channel logout: sid not set in query parameter")) + errorhandler.BadRequest(w, r, fmt.Errorf("front-channel logout: sid not set in query parameter")) return } @@ -295,11 +296,12 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { func New(handler *Handler, prefixes []string) chi.Router { r := chi.NewRouter() - mm := middleware.PrometheusMiddleware("wonderwall") + prometheusMiddleware := prometheus.NewMiddleware("wonderwall") for _, prefix := range prefixes { r.Route(prefix+"/oauth2", func(r chi.Router) { - r.Use(mm.Handler()) + r.Use(prometheusMiddleware.Handler()) + r.Use(correlationid.Handler) r.Use(chi_middleware.NoCache) r.Get("/login", handler.Login) r.Get("/callback", handler.Callback)