Files
wonderwall/pkg/middleware/logentry.go
2025-01-30 14:03:31 +01:00

110 lines
2.7 KiB
Go

package middleware
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5/middleware"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/trace"
httpinternal "github.com/nais/wonderwall/internal/http"
"github.com/nais/wonderwall/pkg/router/paths"
)
// logger holds the configuration of the request-logging middleware built by
// Logger; its Handler method is the actual middleware.
type logger struct {
	// Logger is the logrus logger from which every per-request entry is derived.
	Logger *log.Logger
	// Provider is attached to every per-request entry as the "provider" field.
	Provider string
}
// Logger builds the request/response logging middleware.
//
// The returned value derives its per-request entries from logrus' standard
// logger and tags each of them with the given provider name. Install it via
// its Handler method.
func Logger(provider string) logger {
	var l logger
	l.Logger = log.StandardLogger()
	l.Provider = provider
	return l
}
// LogEntryFrom returns the request-scoped logrus entry stored in r's context
// by the logging middleware.
//
// If the context holds no such entry (for instance because the request never
// passed through Handler), a fallback entry is built from the standard logger
// instead. The fallback is marked with fallback_logger=true and carries the
// request attributes plus the trace/correlation fields.
func LogEntryFrom(r *http.Request) *log.Entry {
	if adapter, found := r.Context().Value(middleware.LogEntryCtxKey).(*logEntryAdapter); found {
		return adapter.Logger
	}

	fallback := log.NewEntry(log.StandardLogger())
	fallback = fallback.WithField("fallback_logger", true)
	fallback = fallback.WithFields(httpinternal.Attributes(r))
	return fallback.WithFields(traceFields(r))
}
// Handler wraps next so that each request gets its own log entry, stored in
// the request context where LogEntryFrom can find it.
//
// After next returns, the response status, byte count and elapsed time are
// reported through the entry's Write method. Requests whose path ends with the
// ping path are skipped.
func (l *logger) Handler(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		entry := l.newLogEntry(r)
		wrapped := middleware.NewWrapResponseWriter(w, r.ProtoMajor)

		logResponse := !strings.HasSuffix(r.URL.Path, paths.Ping)
		if logResponse {
			start := time.Now()
			// Keep the closure: elapsed time and response stats must be read
			// when the deferred call runs, not when the defer is registered.
			defer func() {
				entry.Write(wrapped.Status(), wrapped.BytesWritten(), wrapped.Header(), time.Since(start), nil)
			}()
		}

		next.ServeHTTP(wrapped, middleware.WithLogEntry(r, entry))
	})
}
// newLogEntry creates the per-request log adapter.
//
// The adapter holds two things:
//   - a logrus entry bound to the request context and tagged with the provider
//     name and trace/correlation fields;
//   - the request attributes, which are kept separately for the response line.
func (l *logger) newLogEntry(r *http.Request) *logEntryAdapter {
	attrs := httpinternal.Attributes(r)
	base := l.Logger.WithContext(r.Context()).
		WithField("provider", l.Provider).
		WithFields(traceFields(r))

	return &logEntryAdapter{Logger: base, requestFields: attrs}
}
// logEntryAdapter implements [middleware.LogEntry] on top of a logrus entry.
type logEntryAdapter struct {
	// Logger is the request-scoped entry (provider and trace/correlation
	// fields attached). It is handed out by LogEntryFrom and enriched by Panic.
	Logger *log.Entry
	// requestFields are the request attributes captured when the adapter was
	// created. Within this file they are only added to the line emitted by Write.
	requestFields log.Fields
}
// Write implements middleware.LogEntry by emitting one debug-level line that
// describes the response.
//
// The line combines the request fields captured at creation time with the
// response status, the number of bytes written and the elapsed time in
// fractional milliseconds. The header and extra arguments are ignored.
func (l *logEntryAdapter) Write(status, bytes int, _ http.Header, elapsed time.Duration, _ any) {
	elapsedMs := float64(elapsed) / float64(time.Millisecond)

	l.Logger.
		WithFields(l.requestFields).
		WithFields(log.Fields{
			"response_status":     status,
			"response_bytes":      bytes,
			"response_elapsed_ms": elapsedMs,
		}).
		Debugf("response: %d %s", status, http.StatusText(status))
}
// Panic implements middleware.LogEntry. It is expected to be called by a
// panic-recovering middleware (e.g. chi's middleware.Recoverer) with the
// recovered value v and the stack trace captured at the point of recovery.
//
// The panic value and stack trace are attached to the request-scoped entry,
// so any later line written through it (including the response line from
// Write) carries them. The panic is also reported immediately at error level.
func (l *logEntryAdapter) Panic(v any, stack []byte) {
	// Record the real stack; it is the only way to locate the panic site.
	fields := log.Fields{
		"stacktrace": string(stack),
		"error":      fmt.Sprintf("%+v", v),
	}
	l.Logger = l.Logger.WithFields(fields)

	// The response line is emitted at debug level only, so without this a
	// recovered panic would be invisible unless debug logging is enabled.
	l.Logger.WithFields(l.requestFields).Error("recovered from panic")
}
// traceFields returns the field used to correlate log lines for r.
//
// If the span context found in r's context has a trace ID, the field is
// "trace_id" with that ID. Otherwise it is "correlation_id" with chi's request
// ID from the same context.
func traceFields(r *http.Request) log.Fields {
	ctx := r.Context()
	sc := trace.SpanFromContext(ctx).SpanContext()
	if !sc.HasTraceID() {
		return log.Fields{"correlation_id": middleware.GetReqID(ctx)}
	}
	return log.Fields{"trace_id": sc.TraceID().String()}
}