refactor: add correlation ID for error response logs

Co-Authored-By: Sindre Rødseth Hansen <sindre.rodseth.hansen@nav.no>
This commit is contained in:
Trong Huu Nguyen
2021-10-04 14:36:41 +02:00
parent ce8d8c6460
commit 788ef1278a
4 changed files with 65 additions and 29 deletions

View File

@@ -2,6 +2,7 @@ package errorhandler
import (
"errors"
"github.com/nais/wonderwall/pkg/middleware/correlationid"
log "github.com/sirupsen/logrus"
"net/http"
)
@@ -11,19 +12,28 @@ var (
InvalidLocaleError = errors.New("InvalidLocale")
)
func respondError(w http.ResponseWriter, statusCode int, cause error) {
log.Error(cause)
func respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause error) {
id, ok := correlationid.GetFromContext(r.Context())
if !ok {
log.Warnf("no correlation id in context")
}
logFields := log.Fields{
"correlation_id": id,
}
log.WithFields(logFields).Error(cause)
w.WriteHeader(statusCode)
}
func InternalError(w http.ResponseWriter, cause error) {
respondError(w, http.StatusInternalServerError, cause)
func InternalError(w http.ResponseWriter, r *http.Request, cause error) {
respondError(w, r, http.StatusInternalServerError, cause)
}
func BadRequest(w http.ResponseWriter, cause error) {
respondError(w, http.StatusBadRequest, cause)
func BadRequest(w http.ResponseWriter, r *http.Request, cause error) {
respondError(w, r, http.StatusBadRequest, cause)
}
func Unauthorized(w http.ResponseWriter, cause error) {
respondError(w, http.StatusUnauthorized, cause)
func Unauthorized(w http.ResponseWriter, r *http.Request, cause error) {
respondError(w, r, http.StatusUnauthorized, cause)
}

View File

@@ -0,0 +1,24 @@
package correlationid
import (
"context"
"github.com/google/uuid"
"net/http"
)
// contextKey is the type of contextKeys used for correlation IDs.
type contextKey struct{}
func GetFromContext(ctx context.Context) (string, bool) {
id, ok := ctx.Value(contextKey{}).(string)
return id, ok
}
func Handler(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = context.WithValue(ctx, contextKey{}, uuid.New().String())
next.ServeHTTP(w, r.WithContext(ctx))
}
return http.HandlerFunc(fn)
}

View File

@@ -1,7 +1,7 @@
// This code was originally written by Rene Zbinden and modified by Vladimir Konovalov.
// Copied from https://github.com/766b/chi-prometheus and further adapted.
package middleware
package prometheus
import (
"net/http"
@@ -31,7 +31,7 @@ type Middleware struct {
}
// NewMiddleware returns a new prometheus Middleware handler.
func PrometheusMiddleware(name string, buckets ...float64) *Middleware {
func NewMiddleware(name string, buckets ...float64) *Middleware {
var m Middleware
m.reqs = prometheus.NewCounterVec(
prometheus.CounterOpts{

View File

@@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"github.com/nais/wonderwall/pkg/middleware/correlationid"
"github.com/nais/wonderwall/pkg/middleware/prometheus"
"io"
"net/http"
"net/url"
@@ -14,7 +16,6 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cryptutil"
"github.com/nais/wonderwall/pkg/errorhandler"
"github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
"github.com/nais/wonderwall/pkg/token"
@@ -75,7 +76,7 @@ func (h *Handler) WithSecureCookie(enabled bool) *Handler {
func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
params, err := auth.GenerateLoginParameters()
if err != nil {
errorhandler.InternalError(w, fmt.Errorf("login: generating login parameters: %w", err))
errorhandler.InternalError(w, r, fmt.Errorf("login: generating login parameters: %w", err))
return
}
@@ -84,9 +85,9 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
cause := fmt.Errorf("login: creating login URL: %w", err)
if errors.Is(err, errorhandler.InvalidSecurityLevelError) || errors.Is(err, errorhandler.InvalidLocaleError) {
errorhandler.BadRequest(w, cause)
errorhandler.BadRequest(w, r, cause)
} else {
errorhandler.InternalError(w, cause)
errorhandler.InternalError(w, r, cause)
}
return
@@ -99,7 +100,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
Referer: CanonicalRedirectURL(r),
})
if err != nil {
errorhandler.InternalError(w, fmt.Errorf("login: setting cookie: %w", err))
errorhandler.InternalError(w, r, fmt.Errorf("login: setting cookie: %w", err))
return
}
@@ -109,7 +110,7 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
loginCookie, err := h.getLoginCookie(w, r)
if err != nil {
errorhandler.Unauthorized(w, fmt.Errorf("callback: fetching login cookie: %w", err))
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: fetching login cookie: %w", err))
return
}
@@ -117,18 +118,18 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
if params.Get("error") != "" {
oauthError := params.Get("error")
oauthErrorDescription := params.Get("error_description")
errorhandler.Unauthorized(w, fmt.Errorf("callback: error from identity provider: %s: %s", oauthError, oauthErrorDescription))
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: error from identity provider: %s: %s", oauthError, oauthErrorDescription))
return
}
if params.Get("state") != loginCookie.State {
errorhandler.Unauthorized(w, fmt.Errorf("callback: state parameter mismatch"))
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: state parameter mismatch"))
return
}
assertion, err := h.Config.SignedJWTProfileAssertion(time.Second * 100)
if err != nil {
errorhandler.InternalError(w, fmt.Errorf("callback: creating client assertion: %w", err))
errorhandler.InternalError(w, r, fmt.Errorf("callback: creating client assertion: %w", err))
return
}
@@ -140,13 +141,13 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
tokens, err := h.OauthConfig.Exchange(r.Context(), params.Get("code"), opts...)
if err != nil {
errorhandler.Unauthorized(w, fmt.Errorf("callback: exchanging code: %w", err))
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: exchanging code: %w", err))
return
}
idToken, err := token.ParseIDToken(h.jwkSet, tokens)
if err != nil {
errorhandler.Unauthorized(w, fmt.Errorf("callback: parsing id_token: %w", err))
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: parsing id_token: %w", err))
return
}
@@ -164,19 +165,19 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
err = idToken.Validate(validateOpts...)
if err != nil {
errorhandler.Unauthorized(w, fmt.Errorf("callback: validating id_token: %w", err))
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: validating id_token: %w", err))
return
}
externalSessionID, ok := idToken.GetSID()
if !ok {
errorhandler.Unauthorized(w, fmt.Errorf("callback: missing required 'sid' claim in id_token"))
errorhandler.Unauthorized(w, r, fmt.Errorf("callback: missing required 'sid' claim in id_token"))
return
}
err = h.createSession(w, r, externalSessionID, tokens, idToken)
if err != nil {
errorhandler.InternalError(w, fmt.Errorf("callback: creating session: %w", err))
errorhandler.InternalError(w, r, fmt.Errorf("callback: creating session: %w", err))
return
}
@@ -241,7 +242,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
u, err := url.Parse(h.Config.WellKnown.EndSessionEndpoint)
if err != nil {
errorhandler.InternalError(w, fmt.Errorf("logout: parsing end session endpoint: %w", err))
errorhandler.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
return
}
@@ -252,7 +253,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
idToken = sess.IDToken
err = h.destroySession(w, r, h.localSessionID(sess.ExternalSessionID))
if err != nil {
errorhandler.InternalError(w, fmt.Errorf("logout: destroying session: %w", err))
errorhandler.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
return
}
}
@@ -277,7 +278,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
sid := params.Get("sid")
if len(sid) == 0 {
errorhandler.BadRequest(w, fmt.Errorf("front-channel logout: sid not set in query parameter"))
errorhandler.BadRequest(w, r, fmt.Errorf("front-channel logout: sid not set in query parameter"))
return
}
@@ -295,11 +296,12 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
func New(handler *Handler, prefixes []string) chi.Router {
r := chi.NewRouter()
mm := middleware.PrometheusMiddleware("wonderwall")
prometheusMiddleware := prometheus.NewMiddleware("wonderwall")
for _, prefix := range prefixes {
r.Route(prefix+"/oauth2", func(r chi.Router) {
r.Use(mm.Handler())
r.Use(prometheusMiddleware.Handler())
r.Use(correlationid.Handler)
r.Use(chi_middleware.NoCache)
r.Get("/login", handler.Login)
r.Get("/callback", handler.Callback)