package middleware import ( "bytes" "io" "io/ioutil" "net/http" "time" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/httplog" "github.com/rs/zerolog" ) // LogEntryHandler is copied verbatim from httplog package to replace with our own requestLogger implementation. func LogEntryHandler(logger zerolog.Logger) func(next http.Handler) http.Handler { var f middleware.LogFormatter = &requestLogger{logger} return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { entry := f.NewLogEntry(r) ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) buf := newLimitBuffer(512) ww.Tee(buf) t1 := time.Now() defer func() { var respBody []byte if ww.Status() >= 400 { respBody, _ = ioutil.ReadAll(buf) } entry.Write(ww.Status(), ww.BytesWritten(), ww.Header(), time.Since(t1), respBody) }() next.ServeHTTP(ww, middleware.WithLogEntry(r, entry)) } return http.HandlerFunc(fn) } } type requestLogger struct { Logger zerolog.Logger } func (l *requestLogger) NewLogEntry(r *http.Request) middleware.LogEntry { entry := &httplog.RequestLoggerEntry{} entry.Logger = l.Logger.With().Fields(requestLogFields(r)).Logger() return entry } func requestLogFields(r *http.Request) map[string]interface{} { requestFields := map[string]interface{}{ "requestMethod": r.Method, "requestPath": r.URL.Path, "proto": r.Proto, } if reqID := middleware.GetReqID(r.Context()); reqID != "" { requestFields["requestID"] = reqID } return map[string]interface{}{ "httpRequest": requestFields, } } // limitBuffer is used to pipe response body information from the // response writer to a certain limit amount. The idea is to read // a portion of the response body such as an error response so we // may log it. // // Copied verbatim from httplog package as it is unexported. type limitBuffer struct { *bytes.Buffer limit int } func newLimitBuffer(size int) io.ReadWriter { return limitBuffer{ Buffer: bytes.NewBuffer(make([]byte, 0, size)), limit: size, } } func (b limitBuffer) Write(p []byte) (n int, err error) { if b.Buffer.Len() >= b.limit { return len(p), nil } limit := b.limit if len(p) < limit { limit = len(p) } return b.Buffer.Write(p[:limit]) } func (b limitBuffer) Read(p []byte) (n int, err error) { return b.Buffer.Read(p) }