mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 23:32:57 +00:00
refactor: use errgroup for coordinated server lifecycle
Move metrics and probe listeners from main.go into server.go and manage all servers with errgroup. This replaces log.Fatalf goroutines with proper error propagation and ties all server lifetimes together so a failure in any listener triggers graceful shutdown of the others.
This commit is contained in:
@@ -3,8 +3,6 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/pprof"
|
||||
|
||||
"github.com/KimMachineGun/automemlimit/memlimit"
|
||||
"github.com/nais/wonderwall/internal/crypto"
|
||||
@@ -12,7 +10,6 @@ import (
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/handler"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
|
||||
"github.com/nais/wonderwall/pkg/openid/provider"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
@@ -84,44 +81,7 @@ func run() error {
|
||||
|
||||
r := router.New(src, cfg)
|
||||
|
||||
if cfg.MetricsBindAddress != "" {
|
||||
go func() {
|
||||
log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress)
|
||||
err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider)
|
||||
if err != nil {
|
||||
log.Fatalf("fatal: metrics server error: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if cfg.ProbeBindAddress != "" {
|
||||
go func() {
|
||||
mux := http.NewServeMux()
|
||||
healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
})
|
||||
mux.HandleFunc("/", healthz)
|
||||
mux.HandleFunc("/healthz", healthz)
|
||||
|
||||
if cfg.PprofEnabled {
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
log.Infof("pprof: enabled on %s/debug/pprof/", cfg.ProbeBindAddress)
|
||||
}
|
||||
|
||||
log.Debugf("probe: listening on %s", cfg.ProbeBindAddress)
|
||||
err := http.ListenAndServe(cfg.ProbeBindAddress, mux)
|
||||
if err != nil {
|
||||
log.Fatalf("fatal: probe server error: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
return server.Start(cfg, r)
|
||||
return server.Start(ctx, cfg, r)
|
||||
}
|
||||
|
||||
func standalone(ctx context.Context, cfg *config.Config, crypt crypto.Crypter) (*handler.Standalone, error) {
|
||||
|
||||
2
go.mod
2
go.mod
@@ -40,6 +40,7 @@ require (
|
||||
go.opentelemetry.io/otel/trace v1.43.0
|
||||
golang.org/x/crypto v0.50.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/sync v0.20.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -102,7 +103,6 @@ require (
|
||||
golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 // indirect
|
||||
golang.org/x/mod v0.34.0 // indirect
|
||||
golang.org/x/net v0.52.0 // indirect
|
||||
golang.org/x/sync v0.20.0 // indirect
|
||||
golang.org/x/sys v0.43.0 // indirect
|
||||
golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c // indirect
|
||||
golang.org/x/text v0.36.0 // indirect
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -127,17 +123,8 @@ func InitLabels() {
|
||||
Logins.With(prometheus.Labels{LabelAmr: "", LabelRedirect: ""})
|
||||
}
|
||||
|
||||
func Handle(address string, provider config.Provider) error {
|
||||
WithProvider(string(provider))
|
||||
Register(prometheus.DefaultRegisterer)
|
||||
InitLabels()
|
||||
|
||||
handler := promhttp.Handler()
|
||||
return http.ListenAndServe(address, handler)
|
||||
}
|
||||
|
||||
func Register(registry prometheus.Registerer) {
|
||||
registry.MustRegister(
|
||||
func Register() {
|
||||
prometheus.DefaultRegisterer.MustRegister(
|
||||
RedisLatency,
|
||||
Logins,
|
||||
Logouts,
|
||||
|
||||
@@ -3,20 +3,27 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"net/http/pprof"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
)
|
||||
|
||||
func Start(cfg *config.Config, r chi.Router) error {
|
||||
server := http.Server{
|
||||
func Start(ctx context.Context, cfg *config.Config, r chi.Router) error {
|
||||
ctx, stop := signal.NotifyContext(ctx, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
defer stop()
|
||||
|
||||
mainServer := &http.Server{
|
||||
Addr: cfg.BindAddress,
|
||||
Handler: r,
|
||||
ReadHeaderTimeout: 10 * time.Second, // Prevents slowloris attacks (connections held open without sending headers).
|
||||
@@ -25,42 +32,109 @@ func Start(cfg *config.Config, r chi.Router) error {
|
||||
// ReadTimeout/WriteTimeout intentionally omitted - a reverse proxy must support slow transfers.
|
||||
}
|
||||
|
||||
serverCtx, serverStopCtx := context.WithCancel(context.Background())
|
||||
servers := []*http.Server{mainServer}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
go func() {
|
||||
s := <-sig
|
||||
log.Infof("server: received %q; waiting for %s before starting graceful shutdown...", s, cfg.ShutdownWaitBeforePeriod)
|
||||
time.Sleep(cfg.ShutdownWaitBeforePeriod)
|
||||
g, gctx := errgroup.WithContext(ctx)
|
||||
|
||||
// the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent, so we need to subtract the wait-before period to exit before SIGKILL
|
||||
shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod
|
||||
shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, shutdownTimeout)
|
||||
|
||||
go func() {
|
||||
<-shutdownCtx.Done()
|
||||
if errors.Is(shutdownCtx.Err(), context.DeadlineExceeded) {
|
||||
log.Fatalf("server: graceful shutdown timed out after %s; forcing exit.", shutdownTimeout)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout)
|
||||
err := server.Shutdown(shutdownCtx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
g.Go(func() error {
|
||||
log.Infof("server: listening on %s", cfg.BindAddress)
|
||||
if err := mainServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("server: %w", err)
|
||||
}
|
||||
shutdownStopCtx()
|
||||
serverStopCtx()
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
|
||||
log.Infof("server: listening on %s", cfg.BindAddress)
|
||||
err := server.ListenAndServe()
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
if cfg.MetricsBindAddress != "" {
|
||||
metricsServer := newMetricsServer(cfg)
|
||||
servers = append(servers, metricsServer)
|
||||
|
||||
g.Go(func() error {
|
||||
log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress)
|
||||
if err := metricsServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("metrics: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
<-serverCtx.Done()
|
||||
log.Infof("server: shutdown completed")
|
||||
return nil
|
||||
if cfg.ProbeBindAddress != "" {
|
||||
probeServer := newProbeServer(cfg)
|
||||
servers = append(servers, probeServer)
|
||||
|
||||
g.Go(func() error {
|
||||
log.Debugf("probe: listening on %s", cfg.ProbeBindAddress)
|
||||
if err := probeServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("probe: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
g.Go(func() error {
|
||||
<-gctx.Done()
|
||||
|
||||
log.Infof("server: received shutdown signal; waiting for %s before starting graceful shutdown...", cfg.ShutdownWaitBeforePeriod)
|
||||
time.Sleep(cfg.ShutdownWaitBeforePeriod)
|
||||
|
||||
// the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent,
|
||||
// so we need to subtract the wait-before period to exit before SIGKILL
|
||||
shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
defer cancel()
|
||||
|
||||
log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout)
|
||||
|
||||
var errs []error
|
||||
for _, srv := range servers {
|
||||
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := errors.Join(errs...); err != nil {
|
||||
return fmt.Errorf("graceful shutdown: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("server: shutdown completed")
|
||||
return nil
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func newMetricsServer(cfg *config.Config) *http.Server {
|
||||
metrics.WithProvider(string(cfg.OpenID.Provider))
|
||||
metrics.Register()
|
||||
metrics.InitLabels()
|
||||
|
||||
return &http.Server{
|
||||
Addr: cfg.MetricsBindAddress,
|
||||
Handler: promhttp.Handler(),
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func newProbeServer(cfg *config.Config) *http.Server {
|
||||
mux := http.NewServeMux()
|
||||
healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("ok"))
|
||||
})
|
||||
mux.HandleFunc("/", healthz)
|
||||
mux.HandleFunc("/healthz", healthz)
|
||||
|
||||
if cfg.PprofEnabled {
|
||||
mux.HandleFunc("/debug/pprof/", pprof.Index)
|
||||
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
|
||||
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
|
||||
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
|
||||
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
|
||||
log.Infof("pprof: enabled on %s/debug/pprof/", cfg.ProbeBindAddress)
|
||||
}
|
||||
|
||||
return &http.Server{
|
||||
Addr: cfg.ProbeBindAddress,
|
||||
Handler: mux,
|
||||
ReadHeaderTimeout: 10 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user