mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-11 19:06:43 +00:00
refactor: add correlation ID for error response logs
Co-Authored-By: Sindre Rødseth Hansen <sindre.rodseth.hansen@nav.no>
This commit is contained in:
@@ -2,6 +2,7 @@ package errorhandler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/nais/wonderwall/pkg/middleware/correlationid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
)
|
||||
@@ -11,19 +12,28 @@ var (
|
||||
InvalidLocaleError = errors.New("InvalidLocale")
|
||||
)
|
||||
|
||||
func respondError(w http.ResponseWriter, statusCode int, cause error) {
|
||||
log.Error(cause)
|
||||
func respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause error) {
|
||||
id, ok := correlationid.GetFromContext(r.Context())
|
||||
if !ok {
|
||||
log.Warnf("no correlation id in context")
|
||||
}
|
||||
|
||||
logFields := log.Fields{
|
||||
"correlation_id": id,
|
||||
}
|
||||
|
||||
log.WithFields(logFields).Error(cause)
|
||||
w.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func InternalError(w http.ResponseWriter, cause error) {
|
||||
respondError(w, http.StatusInternalServerError, cause)
|
||||
func InternalError(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
respondError(w, r, http.StatusInternalServerError, cause)
|
||||
}
|
||||
|
||||
func BadRequest(w http.ResponseWriter, cause error) {
|
||||
respondError(w, http.StatusBadRequest, cause)
|
||||
func BadRequest(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
respondError(w, r, http.StatusBadRequest, cause)
|
||||
}
|
||||
|
||||
func Unauthorized(w http.ResponseWriter, cause error) {
|
||||
respondError(w, http.StatusUnauthorized, cause)
|
||||
func Unauthorized(w http.ResponseWriter, r *http.Request, cause error) {
|
||||
respondError(w, r, http.StatusUnauthorized, cause)
|
||||
}
|
||||
|
||||
24
pkg/middleware/correlationid/correlationid.go
Normal file
24
pkg/middleware/correlationid/correlationid.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package correlationid
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/google/uuid"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// contextKey is the type of contextKeys used for correlation IDs.
|
||||
type contextKey struct{}
|
||||
|
||||
func GetFromContext(ctx context.Context) (string, bool) {
|
||||
id, ok := ctx.Value(contextKey{}).(string)
|
||||
return id, ok
|
||||
}
|
||||
|
||||
func Handler(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, contextKey{}, uuid.New().String())
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
return http.HandlerFunc(fn)
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
// This code was originally written by Rene Zbinden and modified by Vladimir Konovalov.
|
||||
// Copied from https://github.com/766b/chi-prometheus and further adapted.
|
||||
|
||||
package middleware
|
||||
package prometheus
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
@@ -31,7 +31,7 @@ type Middleware struct {
|
||||
}
|
||||
|
||||
// NewMiddleware returns a new prometheus Middleware handler.
|
||||
func PrometheusMiddleware(name string, buckets ...float64) *Middleware {
|
||||
func NewMiddleware(name string, buckets ...float64) *Middleware {
|
||||
var m Middleware
|
||||
m.reqs = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/nais/wonderwall/pkg/middleware/correlationid"
|
||||
"github.com/nais/wonderwall/pkg/middleware/prometheus"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@@ -14,7 +16,6 @@ import (
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"github.com/nais/wonderwall/pkg/errorhandler"
|
||||
"github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
|
||||
@@ -75,7 +76,7 @@ func (h *Handler) WithSecureCookie(enabled bool) *Handler {
|
||||
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
params, err := auth.GenerateLoginParameters()
|
||||
if err != nil {
|
||||
errorhandler.InternalError(w, fmt.Errorf("login: generating login parameters: %w", err))
|
||||
errorhandler.InternalError(w, r, fmt.Errorf("login: generating login parameters: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -84,9 +85,9 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
cause := fmt.Errorf("login: creating login URL: %w", err)
|
||||
|
||||
if errors.Is(err, errorhandler.InvalidSecurityLevelError) || errors.Is(err, errorhandler.InvalidLocaleError) {
|
||||
errorhandler.BadRequest(w, cause)
|
||||
errorhandler.BadRequest(w, r, cause)
|
||||
} else {
|
||||
errorhandler.InternalError(w, cause)
|
||||
errorhandler.InternalError(w, r, cause)
|
||||
}
|
||||
|
||||
return
|
||||
@@ -99,7 +100,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
Referer: CanonicalRedirectURL(r),
|
||||
})
|
||||
if err != nil {
|
||||
errorhandler.InternalError(w, fmt.Errorf("login: setting cookie: %w", err))
|
||||
errorhandler.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,7 +110,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
loginCookie, err := h.getLoginCookie(w, r)
|
||||
if err != nil {
|
||||
errorhandler.Unauthorized(w, fmt.Errorf("callback: fetching login cookie: %w", err))
|
||||
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: fetching login cookie: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -117,18 +118,18 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
if params.Get("error") != "" {
|
||||
oauthError := params.Get("error")
|
||||
oauthErrorDescription := params.Get("error_description")
|
||||
errorhandler.Unauthorized(w, fmt.Errorf("callback: error from identity provider: %s: %s", oauthError, oauthErrorDescription))
|
||||
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: error from identity provider: %s: %s", oauthError, oauthErrorDescription))
|
||||
return
|
||||
}
|
||||
|
||||
if params.Get("state") != loginCookie.State {
|
||||
errorhandler.Unauthorized(w, fmt.Errorf("callback: state parameter mismatch"))
|
||||
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: state parameter mismatch"))
|
||||
return
|
||||
}
|
||||
|
||||
assertion, err := h.Config.SignedJWTProfileAssertion(time.Second * 100)
|
||||
if err != nil {
|
||||
errorhandler.InternalError(w, fmt.Errorf("callback: creating client assertion: %w", err))
|
||||
errorhandler.InternalError(w, r, fmt.Errorf("callback: creating client assertion: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -140,13 +141,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
tokens, err := h.OauthConfig.Exchange(r.Context(), params.Get("code"), opts...)
|
||||
if err != nil {
|
||||
errorhandler.Unauthorized(w, fmt.Errorf("callback: exchanging code: %w", err))
|
||||
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: exchanging code: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := token.ParseIDToken(h.jwkSet, tokens)
|
||||
if err != nil {
|
||||
errorhandler.Unauthorized(w, fmt.Errorf("callback: parsing id_token: %w", err))
|
||||
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: parsing id_token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -164,19 +165,19 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = idToken.Validate(validateOpts...)
|
||||
if err != nil {
|
||||
errorhandler.Unauthorized(w, fmt.Errorf("callback: validating id_token: %w", err))
|
||||
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: validating id_token: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
externalSessionID, ok := idToken.GetSID()
|
||||
if !ok {
|
||||
errorhandler.Unauthorized(w, fmt.Errorf("callback: missing required 'sid' claim in id_token"))
|
||||
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: missing required 'sid' claim in id_token"))
|
||||
return
|
||||
}
|
||||
|
||||
err = h.createSession(w, r, externalSessionID, tokens, idToken)
|
||||
if err != nil {
|
||||
errorhandler.InternalError(w, fmt.Errorf("callback: creating session: %w", err))
|
||||
errorhandler.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -241,7 +242,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint)
|
||||
if err != nil {
|
||||
errorhandler.InternalError(w, fmt.Errorf("logout: parsing end session endpoint: %w", err))
|
||||
errorhandler.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -252,7 +253,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
idToken = sess.IDToken
|
||||
err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID))
|
||||
if err != nil {
|
||||
errorhandler.InternalError(w, fmt.Errorf("logout: destroying session: %w", err))
|
||||
errorhandler.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -277,7 +278,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
sid := params.Get("sid")
|
||||
|
||||
if len(sid) == 0 {
|
||||
errorhandler.BadRequest(w, fmt.Errorf("front-channel logout: sid not set in query parameter"))
|
||||
errorhandler.BadRequest(w, r, fmt.Errorf("front-channel logout: sid not set in query parameter"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -295,11 +296,12 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
func New(handler *Handler, prefixes []string) chi.Router {
|
||||
r := chi.NewRouter()
|
||||
mm := middleware.PrometheusMiddleware("wonderwall")
|
||||
prometheusMiddleware := prometheus.NewMiddleware("wonderwall")
|
||||
|
||||
for _, prefix := range prefixes {
|
||||
r.Route(prefix+"/oauth2", func(r chi.Router) {
|
||||
r.Use(mm.Handler())
|
||||
r.Use(prometheusMiddleware.Handler())
|
||||
r.Use(correlationid.Handler)
|
||||
r.Use(chi_middleware.NoCache)
|
||||
r.Get("/login", handler.Login)
|
||||
r.Get("/callback", handler.Callback)
|
||||
|
||||
Reference in New Issue
Block a user