mirror of
https://github.com/nais/wonderwall.git
synced 2026-02-14 17:49:54 +00:00
110 lines
2.7 KiB
Go
110 lines
2.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
log "github.com/sirupsen/logrus"
|
|
"go.opentelemetry.io/otel/trace"
|
|
|
|
httpinternal "github.com/nais/wonderwall/internal/http"
|
|
"github.com/nais/wonderwall/pkg/router/paths"
|
|
)
|
|
|
|
type logger struct {
|
|
Logger *log.Logger
|
|
Provider string
|
|
}
|
|
|
|
// Logger provides a middleware that logs requests and responses.
|
|
func Logger(provider string) logger {
|
|
return logger{
|
|
Logger: log.StandardLogger(),
|
|
Provider: provider,
|
|
}
|
|
}
|
|
|
|
// LogEntryFrom returns the request-scoped log entry that the logger middleware
// stored in r's context. If no such entry is present, it returns a fallback
// entry built from the logrus standard logger, marked with fallback_logger=true
// and carrying the request attributes and trace/correlation fields.
func LogEntryFrom(r *http.Request) *log.Entry {
	if adapter, found := r.Context().Value(middleware.LogEntryCtxKey).(*logEntryAdapter); found {
		return adapter.Logger
	}

	fallback := log.NewEntry(log.StandardLogger())
	fallback = fallback.WithField("fallback_logger", true)
	fallback = fallback.WithFields(httpinternal.Attributes(r))
	return fallback.WithFields(traceFields(r))
}
|
|
|
|
// Handler wraps next so that every request gets a log entry, which downstream
// handlers can reach through the request context. For any path that does not
// end with paths.Ping, the response status, size and latency are written to
// that entry once next returns.
func (l *logger) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		entry := l.newLogEntry(r)
		wrapped := middleware.NewWrapResponseWriter(w, r.ProtoMajor)

		// Requests to the ping path are served but never logged.
		if isPing := strings.HasSuffix(r.URL.Path, paths.Ping); !isPing {
			start := time.Now()
			defer func() {
				entry.Write(wrapped.Status(), wrapped.BytesWritten(), wrapped.Header(), time.Since(start), nil)
			}()
		}

		next.ServeHTTP(wrapped, middleware.WithLogEntry(r, entry))
	})
}
|
|
|
|
// newLogEntry builds the per-request log entry. The entry is a logrus entry
// bound to r's context and tagged with the provider and trace/correlation
// fields. It also holds the request attributes, which Write adds only when
// the response is logged.
func (l *logger) newLogEntry(r *http.Request) *logEntryAdapter {
	reqFields := httpinternal.Attributes(r)

	scoped := l.Logger.
		WithContext(r.Context()).
		WithField("provider", l.Provider).
		WithFields(traceFields(r))

	return &logEntryAdapter{Logger: scoped, requestFields: reqFields}
}
|
|
|
|
// logEntryAdapter implements [middleware.LogEntry]
|
|
type logEntryAdapter struct {
|
|
Logger *log.Entry
|
|
requestFields log.Fields
|
|
}
|
|
|
|
func (l *logEntryAdapter) Write(status, bytes int, _ http.Header, elapsed time.Duration, _ any) {
|
|
responseFields := log.Fields{
|
|
"response_status": status,
|
|
"response_bytes": bytes,
|
|
"response_elapsed_ms": float64(elapsed.Nanoseconds()) / 1000000.0, // in milliseconds, with fractional
|
|
}
|
|
|
|
l.Logger.WithFields(l.requestFields).
|
|
WithFields(responseFields).
|
|
Debugf("response: %d %s", status, http.StatusText(status))
|
|
}
|
|
|
|
// Panic implements [middleware.LogEntry]. It is invoked with the recovered
// panic value v and the goroutine stack trace.
//
// The error and stacktrace fields are attached to the entry, so the response
// line later produced by Write carries them as well. The panic is also logged
// immediately at error level, for two reasons:
//   - chi's Recoverer calls Panic instead of printing the stack itself when a
//     log entry is present.
//   - Write logs only at debug level, and not at all for ping requests.
//
// Without this, recovered panics would be invisible at normal log levels.
func (l *logEntryAdapter) Panic(v interface{}, stack []byte) {
	msg := fmt.Sprintf("%+v", v)

	// Previously the stack argument was discarded and "#" logged in its place.
	l.Logger = l.Logger.WithFields(log.Fields{
		"stacktrace": string(stack),
		"error":      msg,
	})

	l.Logger.WithFields(l.requestFields).Errorf("panic: %s", msg)
}
|
|
|
|
// traceFields returns the single field used to correlate log lines for r.
// When the request context holds a span with a trace ID, the field is
// trace_id. Otherwise it is correlation_id, holding chi's request ID from
// middleware.GetReqID.
func traceFields(r *http.Request) log.Fields {
	ctx := r.Context()

	if sc := trace.SpanFromContext(ctx).SpanContext(); sc.HasTraceID() {
		return log.Fields{"trace_id": sc.TraceID().String()}
	}

	return log.Fields{"correlation_id": middleware.GetReqID(ctx)}
}
|