diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go
index 9a77791..16fbf80 100644
--- a/cmd/wonderwall/main.go
+++ b/cmd/wonderwall/main.go
@@ -3,8 +3,6 @@ package main
 import (
 	"context"
 	"fmt"
-	"net/http"
-	"net/http/pprof"
 
 	"github.com/KimMachineGun/automemlimit/memlimit"
 	"github.com/nais/wonderwall/internal/crypto"
@@ -12,7 +10,6 @@ import (
 	"github.com/nais/wonderwall/pkg/config"
 	"github.com/nais/wonderwall/pkg/cookie"
 	"github.com/nais/wonderwall/pkg/handler"
-	"github.com/nais/wonderwall/pkg/metrics"
 	openidconfig "github.com/nais/wonderwall/pkg/openid/config"
 	"github.com/nais/wonderwall/pkg/openid/provider"
 	"github.com/nais/wonderwall/pkg/router"
@@ -84,44 +81,7 @@ func run() error {
 
 	r := router.New(src, cfg)
 
-	if cfg.MetricsBindAddress != "" {
-		go func() {
-			log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress)
-			err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider)
-			if err != nil {
-				log.Fatalf("fatal: metrics server error: %s", err)
-			}
-		}()
-	}
-
-	if cfg.ProbeBindAddress != "" {
-		go func() {
-			mux := http.NewServeMux()
-			healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				w.WriteHeader(http.StatusOK)
-				w.Write([]byte("ok"))
-			})
-			mux.HandleFunc("/", healthz)
-			mux.HandleFunc("/healthz", healthz)
-
-			if cfg.PprofEnabled {
-				mux.HandleFunc("/debug/pprof/", pprof.Index)
-				mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
-				mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
-				mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
-				mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
-				log.Infof("pprof: enabled on %s/debug/pprof/", cfg.ProbeBindAddress)
-			}
-
-			log.Debugf("probe: listening on %s", cfg.ProbeBindAddress)
-			err := http.ListenAndServe(cfg.ProbeBindAddress, mux)
-			if err != nil {
-				log.Fatalf("fatal: probe server error: %s", err)
-			}
-		}()
-	}
-
-	return server.Start(cfg, r)
+	return server.Start(ctx, cfg, r)
 }
 
 func standalone(ctx context.Context, cfg *config.Config, crypt crypto.Crypter) (*handler.Standalone, error) {
diff --git a/go.mod b/go.mod
index a8ebf97..d0fff15 100644
--- a/go.mod
+++ b/go.mod
@@ -40,6 +40,7 @@ require (
 	go.opentelemetry.io/otel/trace v1.43.0
 	golang.org/x/crypto v0.50.0
 	golang.org/x/oauth2 v0.36.0
+	golang.org/x/sync v0.20.0
 )
 
 require (
@@ -102,7 +103,6 @@ require (
 	golang.org/x/exp/typeparams v0.0.0-20231108232855-2478ac86f678 // indirect
 	golang.org/x/mod v0.34.0 // indirect
 	golang.org/x/net v0.52.0 // indirect
-	golang.org/x/sync v0.20.0 // indirect
 	golang.org/x/sys v0.43.0 // indirect
 	golang.org/x/telemetry v0.0.0-20260311193753-579e4da9a98c // indirect
 	golang.org/x/text v0.36.0 // indirect
diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go
index 82ad634..0a59974 100644
--- a/pkg/metrics/metrics.go
+++ b/pkg/metrics/metrics.go
@@ -1,15 +1,11 @@
 package metrics
 
 import (
-	"net/http"
 	"net/url"
 	"strings"
 	"time"
 
 	"github.com/prometheus/client_golang/prometheus"
-	"github.com/prometheus/client_golang/prometheus/promhttp"
-
-	"github.com/nais/wonderwall/pkg/config"
 )
 
 const (
@@ -127,17 +123,9 @@ func InitLabels() {
 	Logins.With(prometheus.Labels{LabelAmr: "", LabelRedirect: ""})
 }
 
-func Handle(address string, provider config.Provider) error {
-	WithProvider(string(provider))
-	Register(prometheus.DefaultRegisterer)
-	InitLabels()
-
-	handler := promhttp.Handler()
-	return http.ListenAndServe(address, handler)
-}
-
-func Register(registry prometheus.Registerer) {
-	registry.MustRegister(
+// Register registers this package's collectors with the default Prometheus registerer.
+func Register() {
+	prometheus.DefaultRegisterer.MustRegister(
 		RedisLatency,
 		Logins,
 		Logouts,
diff --git a/pkg/server/server.go b/pkg/server/server.go
index eef1d44..a8511ba 100644
--- a/pkg/server/server.go
+++ b/pkg/server/server.go
@@ -3,20 +3,31 @@ package server
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/http"
-	"os"
+	"net/http/pprof"
 	"os/signal"
 	"syscall"
 	"time"
 
 	"github.com/go-chi/chi/v5"
+	"github.com/prometheus/client_golang/prometheus/promhttp"
 	log "github.com/sirupsen/logrus"
+	"golang.org/x/sync/errgroup"
 
 	"github.com/nais/wonderwall/pkg/config"
+	"github.com/nais/wonderwall/pkg/metrics"
 )
 
-func Start(cfg *config.Config, r chi.Router) error {
-	server := http.Server{
+// Start runs the main HTTP server and, when their bind addresses are non-empty, the metrics and
+// probe servers. It blocks until ctx is cancelled, a SIGHUP/SIGINT/SIGTERM/SIGQUIT is received,
+// or any listener fails, and then shuts all servers down gracefully. It returns the first
+// listener or shutdown error, or nil after a clean shutdown.
+func Start(ctx context.Context, cfg *config.Config, r chi.Router) error {
+	ctx, stop := signal.NotifyContext(ctx, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
+	defer stop()
+
+	mainServer := &http.Server{
 		Addr:              cfg.BindAddress,
 		Handler:           r,
 		ReadHeaderTimeout: 10 * time.Second, // Prevents slowloris attacks (connections held open without sending headers).
@@ -25,42 +36,120 @@ func Start(cfg *config.Config, r chi.Router) error {
 		// ReadTimeout/WriteTimeout intentionally omitted - a reverse proxy must support slow transfers.
 	}
 
-	serverCtx, serverStopCtx := context.WithCancel(context.Background())
+	servers := []*http.Server{mainServer}
 
-	sig := make(chan os.Signal, 1)
-	signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
-	go func() {
-		s := <-sig
-		log.Infof("server: received %q; waiting for %s before starting graceful shutdown...", s, cfg.ShutdownWaitBeforePeriod)
-		time.Sleep(cfg.ShutdownWaitBeforePeriod)
+	g, gctx := errgroup.WithContext(ctx)
 
-		// the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent, so we need to subtract the wait-before period to exit before SIGKILL
-		shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod
-		shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, shutdownTimeout)
-
-		go func() {
-			<-shutdownCtx.Done()
-			if errors.Is(shutdownCtx.Err(), context.DeadlineExceeded) {
-				log.Fatalf("server: graceful shutdown timed out after %s; forcing exit.", shutdownTimeout)
-			}
-		}()
-
-		log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout)
-		err := server.Shutdown(shutdownCtx)
-		if err != nil {
-			log.Fatal(err)
+	g.Go(func() error {
+		log.Infof("server: listening on %s", cfg.BindAddress)
+		if err := mainServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+			return fmt.Errorf("server: %w", err)
 		}
-		shutdownStopCtx()
-		serverStopCtx()
-	}()
+		return nil
+	})
 
-	log.Infof("server: listening on %s", cfg.BindAddress)
-	err := server.ListenAndServe()
-	if err != nil && !errors.Is(err, http.ErrServerClosed) {
-		return err
+	if cfg.MetricsBindAddress != "" {
+		metricsServer := newMetricsServer(cfg)
+		servers = append(servers, metricsServer)
+
+		g.Go(func() error {
+			log.Debugf("metrics: listening on %s", cfg.MetricsBindAddress)
+			if err := metricsServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+				return fmt.Errorf("metrics: %w", err)
+			}
+			return nil
+		})
 	}
 
-	<-serverCtx.Done()
-	log.Infof("server: shutdown completed")
-	return nil
+	if cfg.ProbeBindAddress != "" {
+		probeServer := newProbeServer(cfg)
+		servers = append(servers, probeServer)
+
+		g.Go(func() error {
+			log.Debugf("probe: listening on %s", cfg.ProbeBindAddress)
+			if err := probeServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+				return fmt.Errorf("probe: %w", err)
+			}
+			return nil
+		})
+	}
+
+	g.Go(func() error {
+		<-gctx.Done()
+
+		// gctx is done either because ctx is done (signal or parent cancellation) or because a
+		// listener above returned an error. Only the former warrants the wait-before delay; on a
+		// listener failure we shut down immediately so the error surfaces without a sleep.
+		if ctx.Err() != nil {
+			log.Infof("server: received shutdown signal; waiting for %s before starting graceful shutdown...", cfg.ShutdownWaitBeforePeriod)
+			time.Sleep(cfg.ShutdownWaitBeforePeriod)
+		} else {
+			log.Warn("server: a listener stopped unexpectedly; shutting down remaining servers...")
+		}
+
+		// the total terminationGracePeriodSeconds in Kubernetes starts immediately when SIGTERM is sent,
+		// so we need to subtract the wait-before period to exit before SIGKILL
+		shutdownTimeout := cfg.ShutdownGracefulPeriod - cfg.ShutdownWaitBeforePeriod
+		shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
+		defer cancel()
+
+		log.Infof("server: starting graceful shutdown (will timeout after %s)...", shutdownTimeout)
+
+		var errs []error
+		for _, srv := range servers {
+			if err := srv.Shutdown(shutdownCtx); err != nil {
+				errs = append(errs, err)
+			}
+		}
+
+		if err := errors.Join(errs...); err != nil {
+			return fmt.Errorf("graceful shutdown: %w", err)
+		}
+
+		log.Infof("server: shutdown completed")
+		return nil
+	})
+
+	return g.Wait()
+}
+
+// newMetricsServer sets up the metrics package (provider label, collector registration, label
+// initialisation) and returns an unstarted server for cfg.MetricsBindAddress.
+func newMetricsServer(cfg *config.Config) *http.Server {
+	metrics.WithProvider(string(cfg.OpenID.Provider))
+	metrics.Register()
+	metrics.InitLabels()
+
+	return &http.Server{
+		Addr:              cfg.MetricsBindAddress,
+		Handler:           promhttp.Handler(),
+		ReadHeaderTimeout: 10 * time.Second,
+	}
+}
+
+// newProbeServer returns an unstarted server for cfg.ProbeBindAddress that answers "ok" on
+// /healthz and any other unmatched path, plus the pprof handlers when cfg.PprofEnabled is set.
+func newProbeServer(cfg *config.Config) *http.Server {
+	mux := http.NewServeMux()
+	healthz := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		w.Write([]byte("ok"))
+	})
+	mux.HandleFunc("/", healthz)
+	mux.HandleFunc("/healthz", healthz)
+
+	if cfg.PprofEnabled {
+		mux.HandleFunc("/debug/pprof/", pprof.Index)
+		mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
+		mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
+		mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
+		mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
+		log.Infof("pprof: enabled on %s/debug/pprof/", cfg.ProbeBindAddress)
+	}
+
+	return &http.Server{
+		Addr:              cfg.ProbeBindAddress,
+		Handler:           mux,
+		ReadHeaderTimeout: 10 * time.Second,
+	}
 }