refactor: use errgroup for coordinated server lifecycle

Move metrics and probe listeners from main.go into server.go and
manage all servers with errgroup. This replaces log.Fatalf goroutines
with proper error propagation and ties all server lifetimes together
so a failure in any listener triggers graceful shutdown of the others.
This commit is contained in:
Trong Huu Nguyen
2026-05-08 07:49:17 +02:00
parent b1e3732ec3
commit 66d734358d
4 changed files with 113 additions and 92 deletions

View File

@@ -3,8 +3,6 @@ package main
import (
"context"
"fmt"
"net/http"
"net/http/pprof"
"github.com/KimMachineGun/automemlimit/memlimit"
"github.com/nais/wonderwall/internal/crypto"
@@ -12,7 +10,6 @@ import (
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/handler"
"github.com/nais/wonderwall/pkg/metrics"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/openid/provider"
"github.com/nais/wonderwall/pkg/router"
@@ -84,44 +81,7 @@ func run() error {
r := router.New(src, cfg)
if cfg.MetricsBindAddress != "" {
go func() {
log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress)
err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider)
if err != nil {
log.Fatalf("fatal: metrics server error: %s", err)
}
}()
}
if cfg.ProbeBindAddress != "" {
go func() {
mux := http.NewServeMux()
healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("ok"))
})
mux.HandleFunc("/", healthz)
mux.HandleFunc("/healthz", healthz)
if cfg.PprofEnabled {
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
log.Infof("pprof: enabled on %s/debug/pprof/", cfg.ProbeBindAddress)
}
log.Debugf("probe: listening on %s", cfg.ProbeBindAddress)
err := http.ListenAndServe(cfg.ProbeBindAddress, mux)
if err != nil {
log.Fatalf("fatal: probe server error: %s", err)
}
}()
}
return server.Start(cfg, r)
return server.Start(ctx, cfg, r)
}
func standalone(ctx context.Context, cfg *config.Config, crypt crypto.Crypter) (*handler.Standalone, error) {

2
go.mod
View File

@@ -40,6 +40,7 @@ require (
go.opentelemetry.io/otel/trace v1.43.0
golang.org/x/crypto v0.50.0
golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.20.0
)
require (
@@ -102,7 +103,6 @@ require (
golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 // indirect
golang.org/x/mod v0.34.0 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sync v0.20.0 // indirect
golang.org/x/sys v0.43.0 // indirect
golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c // indirect
golang.org/x/text v0.36.0 // indirect

View File

@@ -1,15 +1,11 @@
package metrics
import (
"net/http"
"net/url"
"strings"
"time"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/nais/wonderwall/pkg/config"
)
const (
@@ -127,17 +123,8 @@ func InitLabels() {
Logins.With(prometheus.Labels{LabelAmr: "", LabelRedirect: ""})
}
// Handle prepares the package's metrics and then serves them over HTTP on
// address using the Prometheus handler.
//
// Before listening it applies the provider via WithProvider, registers the
// collectors on prometheus.DefaultRegisterer and calls InitLabels. The call
// blocks in http.ListenAndServe and returns whatever error that yields.
func Handle(address string, provider config.Provider) error {
	providerName := string(provider)
	WithProvider(providerName)

	Register(prometheus.DefaultRegisterer)
	InitLabels()

	return http.ListenAndServe(address, promhttp.Handler())
}
func Register(registry prometheus.Registerer) {
registry.MustRegister(
func Register() {
prometheus.DefaultRegisterer.MustRegister(
RedisLatency,
Logins,
Logouts,

View File

@@ -3,20 +3,27 @@ package server
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"net/http/pprof"
"os/signal"
"syscall"
"time"
"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus/promhttp"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/metrics"
)
func Start(cfg *config.Config, r chi.Router) error {
server := http.Server{
func Start(ctx context.Context, cfg *config.Config, r chi.Router) error {
ctx, stop := signal.NotifyContext(ctx, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
defer stop()
mainServer := &http.Server{
Addr: cfg.BindAddress,
Handler: r,
ReadHeaderTimeout: 10 * time.Second, // Prevents slowloris attacks (connections held open without sending headers).
@@ -25,42 +32,109 @@ func Start(cfg *config.Config, r chi.Router) error {
// ReadTimeout/WriteTimeout intentionally omitted - a reverse proxy must support slow transfers.
}
serverCtx, serverStopCtx := context.WithCancel(context.Background())
servers := []*http.Server{mainServer}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
go func() {
s := <-sig
log.Infof("server: received %q; waiting for %s before starting graceful shutdown...", s, cfg.ShutdownWaitBeforePeriod)
time.Sleep(cfg.ShutdownWaitBeforePeriod)
g, gctx := errgroup.WithContext(ctx)
// the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent, so we need to subtract the wait-before period to exit before SIGKILL
shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod
shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, shutdownTimeout)
go func() {
<-shutdownCtx.Done()
if errors.Is(shutdownCtx.Err(), context.DeadlineExceeded) {
log.Fatalf("server: graceful shutdown timed out after %s; forcing exit.", shutdownTimeout)
}
}()
log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout)
err := server.Shutdown(shutdownCtx)
if err != nil {
log.Fatal(err)
g.Go(func() error {
log.Infof("server: listening on %s", cfg.BindAddress)
if err := mainServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("server: %w", err)
}
shutdownStopCtx()
serverStopCtx()
}()
return nil
})
log.Infof("server: listening on %s", cfg.BindAddress)
err := server.ListenAndServe()
if err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
if cfg.MetricsBindAddress != "" {
metricsServer := newMetricsServer(cfg)
servers = append(servers, metricsServer)
g.Go(func() error {
log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress)
if err := metricsServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("metrics: %w", err)
}
return nil
})
}
<-serverCtx.Done()
log.Infof("server: shutdown completed")
return nil
if cfg.ProbeBindAddress != "" {
probeServer := newProbeServer(cfg)
servers = append(servers, probeServer)
g.Go(func() error {
log.Debugf("probe: listening on %s", cfg.ProbeBindAddress)
if err := probeServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("probe: %w", err)
}
return nil
})
}
g.Go(func() error {
<-gctx.Done()
log.Infof("server: received shutdown signal; waiting for %s before starting graceful shutdown...", cfg.ShutdownWaitBeforePeriod)
time.Sleep(cfg.ShutdownWaitBeforePeriod)
// the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent,
// so we need to subtract the wait-before period to exit before SIGKILL
shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
defer cancel()
log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout)
var errs []error
for _, srv := range servers {
if err := srv.Shutdown(shutdownCtx); err != nil {
errs = append(errs, err)
}
}
if err := errors.Join(errs...); err != nil {
return fmt.Errorf("graceful shutdown: %w", err)
}
log.Infof("server: shutdown completed")
return nil
})
return g.Wait()
}
// newMetricsServer prepares the metrics package and returns an HTTP server
// that exposes the Prometheus handler on cfg.MetricsBindAddress.
//
// The provider from cfg.OpenID.Provider is applied first, then the collectors
// are registered and the labels initialised. The server is only constructed
// here; starting and stopping it is left to the caller.
//
// NOTE(review): metrics.Register appears to use MustRegister on the default
// registerer, so calling this function more than once per process would
// presumably panic — confirm before reusing it (e.g. in tests).
func newMetricsServer(cfg *config.Config) *http.Server {
	provider := string(cfg.OpenID.Provider)
	metrics.WithProvider(provider)
	metrics.Register()
	metrics.InitLabels()

	srv := new(http.Server)
	srv.Addr = cfg.MetricsBindAddress
	srv.Handler = promhttp.Handler()
	// Bound the time a client may take to send request headers.
	srv.ReadHeaderTimeout = 10 * time.Second
	return srv
}
// newProbeServer returns an HTTP server for health checks, bound to
// cfg.ProbeBindAddress.
//
// Both "/" and "/healthz" answer 200 with the body "ok"; since "/" is a
// catch-all pattern in http.ServeMux, any path without a more specific route
// gets the same response. When cfg.PprofEnabled is set, the net/http/pprof
// handlers are additionally mounted under /debug/pprof/.
//
// The server is only constructed here; starting and stopping it is left to
// the caller.
func newProbeServer(cfg *config.Config) *http.Server {
	ok := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		// Best effort: the write result is deliberately discarded, as there is
		// nothing useful to do if the prober has already gone away.
		_, _ = w.Write([]byte("ok"))
	}

	routes := http.NewServeMux()
	for _, pattern := range []string{"/", "/healthz"} {
		routes.HandleFunc(pattern, ok)
	}

	if cfg.PprofEnabled {
		profilers := []struct {
			pattern string
			serve   http.HandlerFunc
		}{
			{"/debug/pprof/", pprof.Index},
			{"/debug/pprof/cmdline", pprof.Cmdline},
			{"/debug/pprof/profile", pprof.Profile},
			{"/debug/pprof/symbol", pprof.Symbol},
			{"/debug/pprof/trace", pprof.Trace},
		}
		for _, p := range profilers {
			routes.Handle(p.pattern, p.serve)
		}
		log.Infof("pprof: enabled on %s/debug/pprof/", cfg.ProbeBindAddress)
	}

	srv := &http.Server{
		Addr:    cfg.ProbeBindAddress,
		Handler: routes,
		// Bound the time a client may take to send request headers.
		ReadHeaderTimeout: 10 * time.Second,
	}
	return srv
}