From 66d734358d92fbf63a8b8cbfba12f1ba16c58eb2 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Fri, 8 May 2026 07:49:17 +0200 Subject: [PATCH] refactor: use errgroup for coordinated server lifecycle Move metrics and probe listeners from main.go into server.go and manage all servers with errgroup. This replaces log.Fatalf goroutines with proper error propagation and ties all server lifetimes together so a failure in any listener triggers graceful shutdown of the others. --- cmd/wonderwall/main.go | 42 +----------- go.mod | 2 +- pkg/metrics/metrics.go | 17 +---- pkg/server/server.go | 144 +++++++++++++++++++++++++++++++---------- 4 files changed, 113 insertions(+), 92 deletions(-) diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 9a77791..16fbf80 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -3,8 +3,6 @@ package main import ( "context" "fmt" - "net/http" - "net/http/pprof" "github.com/KimMachineGun/automemlimit/memlimit" "github.com/nais/wonderwall/internal/crypto" @@ -12,7 +10,6 @@ import ( "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/handler" - "github.com/nais/wonderwall/pkg/metrics" openidconfig "github.com/nais/wonderwall/pkg/openid/config" "github.com/nais/wonderwall/pkg/openid/provider" "github.com/nais/wonderwall/pkg/router" @@ -84,44 +81,7 @@ func run() error { r := router.New(src, cfg) - if cfg.MetricsBindAddress != "" { - go func() { - log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress) - err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider) - if err != nil { - log.Fatalf("fatal: metrics server error: %s", err) - } - }() - } - - if cfg.ProbeBindAddress != "" { - go func() { - mux := http.NewServeMux() - healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("ok")) - }) - mux.HandleFunc("/", healthz) - mux.HandleFunc("/healthz", healthz) - - if cfg.PprofEnabled { - mux.HandleFunc("/debug/pprof/", pprof.Index) - mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) - mux.HandleFunc("/debug/pprof/profile", pprof.Profile) - mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) - mux.HandleFunc("/debug/pprof/trace", pprof.Trace) - log.Infof("pprof: enabled on %s/debug/pprof/", cfg.ProbeBindAddress) - } - - log.Debugf("probe: listening on %s", cfg.ProbeBindAddress) - err := http.ListenAndServe(cfg.ProbeBindAddress, mux) - if err != nil { - log.Fatalf("fatal: probe server error: %s", err) - } - }() - } - - return server.Start(cfg, r) + return server.Start(ctx, cfg, r) } func standalone(ctx context.Context, cfg *config.Config, crypt crypto.Crypter) (*handler.Standalone, error) { diff --git a/go.mod b/go.mod index a8ebf97..d0fff15 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( go.opentelemetry.io/otel/trace v1.43.0 golang.org/x/crypto v0.50.0 golang.org/x/oauth2 v0.36.0 + golang.org/x/sync v0.20.0 ) require ( @@ -102,7 +103,6 @@ require ( golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 // indirect golang.org/x/mod v0.34.0 // indirect golang.org/x/net v0.52.0 // indirect - golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.43.0 // indirect golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c // indirect golang.org/x/text v0.36.0 // indirect diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 82ad634..0a59974 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -1,15 +1,11 @@ package metrics import ( - "net/http" "net/url" "strings" "time" "github.com/prometheus/client_golang/prometheus" - "github.com/prometheus/client_golang/prometheus/promhttp" - - "github.com/nais/wonderwall/pkg/config" ) const ( @@ -127,17 +123,8 @@ func InitLabels() { Logins.With(prometheus.Labels{LabelAmr: "", LabelRedirect: ""}) } -func Handle(address string, provider config.Provider) error { - WithProvider(string(provider)) - Register(prometheus.DefaultRegisterer) - InitLabels() - - handler := promhttp.Handler() - return http.ListenAndServe(address, handler) -} - -func Register(registry prometheus.Registerer) { - registry.MustRegister( +func Register() { + prometheus.DefaultRegisterer.MustRegister( RedisLatency, Logins, Logouts, diff --git a/pkg/server/server.go b/pkg/server/server.go index eef1d44..a8511ba 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -3,20 +3,27 @@ package server import ( "context" "errors" + "fmt" "net/http" - "os" + "net/http/pprof" "os/signal" "syscall" "time" "github.com/go-chi/chi/v5" + "github.com/prometheus/client_golang/prometheus/promhttp" log "github.com/sirupsen/logrus" + "golang.org/x/sync/errgroup" "github.com/nais/wonderwall/pkg/config" + "github.com/nais/wonderwall/pkg/metrics" ) -func Start(cfg *config.Config, r chi.Router) error { - server := http.Server{ +func Start(ctx context.Context, cfg *config.Config, r chi.Router) error { + ctx, stop := signal.NotifyContext(ctx, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + defer stop() + + mainServer := &http.Server{ Addr: cfg.BindAddress, Handler: r, ReadHeaderTimeout: 10 * time.Second, // Prevents slowloris attacks (connections held open without sending headers). @@ -25,42 +32,109 @@ func Start(cfg *config.Config, r chi.Router) error { // ReadTimeout/WriteTimeout intentionally omitted - a reverse proxy must support slow transfers. } - serverCtx, serverStopCtx := context.WithCancel(context.Background()) + servers := []*http.Server{mainServer} - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - go func() { - s := <-sig - log.Infof("server: received %q; waiting for %s before starting graceful shutdown...", s, cfg.ShutdownWaitBeforePeriod) - time.Sleep(cfg.ShutdownWaitBeforePeriod) + g, gctx := errgroup.WithContext(ctx) - // the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent, so we need to subtract the wait-before period to exit before SIGKILL - shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod - shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, shutdownTimeout) - - go func() { - <-shutdownCtx.Done() - if errors.Is(shutdownCtx.Err(), context.DeadlineExceeded) { - log.Fatalf("server: graceful shutdown timed out after %s; forcing exit.", shutdownTimeout) - } - }() - - log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout) - err := server.Shutdown(shutdownCtx) - if err != nil { - log.Fatal(err) + g.Go(func() error { + log.Infof("server: listening on %s", cfg.BindAddress) + if err := mainServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("server: %w", err) } - shutdownStopCtx() - serverStopCtx() - }() + return nil + }) - log.Infof("server: listening on %s", cfg.BindAddress) - err := server.ListenAndServe() - if err != nil && !errors.Is(err, http.ErrServerClosed) { - return err + if cfg.MetricsBindAddress != "" { + metricsServer := newMetricsServer(cfg) + servers = append(servers, metricsServer) + + g.Go(func() error { + log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress) + if err := metricsServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("metrics: %w", err) + } + return nil + }) } - <-serverCtx.Done() - log.Infof("server: shutdown completed") - return nil + if cfg.ProbeBindAddress != "" { + probeServer := newProbeServer(cfg) + servers = append(servers, probeServer) + + g.Go(func() error { + log.Debugf("probe: listening on %s", cfg.ProbeBindAddress) + if err := probeServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("probe: %w", err) + } + return nil + }) + } + + g.Go(func() error { + <-gctx.Done() + + log.Infof("server: received shutdown signal; waiting for %s before starting graceful shutdown...", cfg.ShutdownWaitBeforePeriod) + time.Sleep(cfg.ShutdownWaitBeforePeriod) + + // the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent, + // so we need to subtract the wait-before period to exit before SIGKILL + shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + + log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout) + + var errs []error + for _, srv := range servers { + if err := srv.Shutdown(shutdownCtx); err != nil { + errs = append(errs, err) + } + } + + if err := errors.Join(errs...); err != nil { + return fmt.Errorf("graceful shutdown: %w", err) + } + + log.Infof("server: shutdown completed") + return nil + }) + + return g.Wait() +} + +func newMetricsServer(cfg *config.Config) *http.Server { + metrics.WithProvider(string(cfg.OpenID.Provider)) + metrics.Register() + metrics.InitLabels() + + return &http.Server{ + Addr: cfg.MetricsBindAddress, + Handler: promhttp.Handler(), + ReadHeaderTimeout: 10 * time.Second, + } +} + +func newProbeServer(cfg *config.Config) *http.Server { + mux := http.NewServeMux() + healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) + }) + mux.HandleFunc("/", healthz) + mux.HandleFunc("/healthz", healthz) + + if cfg.PprofEnabled { + mux.HandleFunc("/debug/pprof/", pprof.Index) + mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline) + mux.HandleFunc("/debug/pprof/profile", pprof.Profile) + mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol) + mux.HandleFunc("/debug/pprof/trace", pprof.Trace) + log.Infof("pprof: enabled on %s/debug/pprof/", cfg.ProbeBindAddress) + } + + return &http.Server{ + Addr: cfg.ProbeBindAddress, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } }