From 569855cef2d455b0c420cb6fa00c2604d5cb3eff Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Mon, 4 Oct 2021 18:45:40 +0200 Subject: [PATCH] refactor: minor cleanups for middleware --- pkg/errorhandler/errorhandler.go | 4 ++-- .../{correlationid => }/correlationid.go | 6 ++--- pkg/middleware/{prometheus => }/prometheus.go | 22 +++++++------------ pkg/router/router.go | 9 ++++---- 4 files changed, 17 insertions(+), 24 deletions(-) rename pkg/middleware/{correlationid => }/correlationid.go (75%) rename pkg/middleware/{prometheus => }/prometheus.go (78%) diff --git a/pkg/errorhandler/errorhandler.go b/pkg/errorhandler/errorhandler.go index 48bb545..469db5c 100644 --- a/pkg/errorhandler/errorhandler.go +++ b/pkg/errorhandler/errorhandler.go @@ -2,7 +2,7 @@ package errorhandler import ( "errors" - "github.com/nais/wonderwall/pkg/middleware/correlationid" + "github.com/nais/wonderwall/pkg/middleware" log "github.com/sirupsen/logrus" "net/http" ) @@ -13,7 +13,7 @@ var ( ) func respondError(w http.ResponseWriter, r *http.Request, statusCode int, cause error) { - id, ok := correlationid.GetFromContext(r.Context()) + id, ok := middleware.GetCorrelationID(r.Context()) if !ok { log.Warnf("no correlation id in context") } diff --git a/pkg/middleware/correlationid/correlationid.go b/pkg/middleware/correlationid.go similarity index 75% rename from pkg/middleware/correlationid/correlationid.go rename to pkg/middleware/correlationid.go index c8a0622..721eab5 100644 --- a/pkg/middleware/correlationid/correlationid.go +++ b/pkg/middleware/correlationid.go @@ -1,4 +1,4 @@ -package correlationid +package middleware import ( "context" @@ -9,12 +9,12 @@ import ( // contextKey is the type of contextKeys used for correlation IDs. type contextKey struct{} -func GetFromContext(ctx context.Context) (string, bool) { +func GetCorrelationID(ctx context.Context) (string, bool) { id, ok := ctx.Value(contextKey{}).(string) return id, ok } -func Handler(next http.Handler) http.Handler { +func CorrelationIDHandler(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ctx = context.WithValue(ctx, contextKey{}, uuid.New().String()) diff --git a/pkg/middleware/prometheus/prometheus.go b/pkg/middleware/prometheus.go similarity index 78% rename from pkg/middleware/prometheus/prometheus.go rename to pkg/middleware/prometheus.go index bb2aa3b..b8ae8c5 100644 --- a/pkg/middleware/prometheus/prometheus.go +++ b/pkg/middleware/prometheus.go @@ -1,7 +1,7 @@ // This code was originally written by Rene Zbinden and modified by Vladimir Konovalov. // Copied from https://github.com/766b/chi-prometheus and further adapted. -package prometheus +package middleware import ( "net/http" @@ -21,18 +21,16 @@ const ( latencyName = "request_duration_seconds" ) -type middleware func(http.Handler) http.Handler - -// Middleware is a handler that exposes prometheus metrics for the number of requests, +// PrometheusMiddleware is a handler that exposes prometheus metrics for the number of requests, // the latency and the response size, partitioned by status code, method and HTTP path. -type Middleware struct { +type PrometheusMiddleware struct { reqs *prometheus.CounterVec latency *prometheus.HistogramVec } -// NewMiddleware returns a new prometheus Middleware handler. -func NewMiddleware(name string, buckets ...float64) *Middleware { - var m Middleware +// NewPrometheusMiddleware returns a new PrometheusMiddleware handler. +func NewPrometheusMiddleware(name string, buckets ...float64) *PrometheusMiddleware { + var m PrometheusMiddleware m.reqs = prometheus.NewCounterVec( prometheus.CounterOpts{ Name: reqsName, @@ -60,7 +58,7 @@ func NewMiddleware(name string, buckets ...float64) *Middleware { return &m } -func (m *Middleware) Initialize(path, method string, code int) { +func (m *PrometheusMiddleware) Initialize(path, method string, code int) { m.reqs.WithLabelValues( strconv.Itoa(code), method, @@ -68,11 +66,7 @@ func (m *Middleware) Initialize(path, method string, code int) { ) } -func (m *Middleware) Handler() middleware { - return m.handler -} - -func (m Middleware) handler(next http.Handler) http.Handler { +func (m *PrometheusMiddleware) Handler(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { start := time.Now() ww := chi_middleware.NewWrapResponseWriter(w, r.ProtoMajor) diff --git a/pkg/router/router.go b/pkg/router/router.go index 0ee4885..ad6379f 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "github.com/nais/wonderwall/pkg/middleware/correlationid" - "github.com/nais/wonderwall/pkg/middleware/prometheus" "io" "net/http" "net/url" @@ -16,6 +14,7 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cryptutil" "github.com/nais/wonderwall/pkg/errorhandler" + "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/session" "github.com/nais/wonderwall/pkg/token" @@ -296,12 +295,12 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { func New(handler *Handler, prefixes []string) chi.Router { r := chi.NewRouter() - prometheusMiddleware := prometheus.NewMiddleware("wonderwall") + prometheusMiddleware := middleware.NewPrometheusMiddleware("wonderwall") for _, prefix := range prefixes { r.Route(prefix+"/oauth2", func(r chi.Router) { - r.Use(prometheusMiddleware.Handler()) - r.Use(correlationid.Handler) + r.Use(prometheusMiddleware.Handler) + r.Use(middleware.CorrelationIDHandler) r.Use(chi_middleware.NoCache) r.Get("/login", handler.Login) r.Get("/callback", handler.Callback)