From c70037bd4c8a87d6c19f88d2911bbe64b189737d Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Fri, 22 Oct 2021 09:05:06 +0200 Subject: [PATCH] refactor: clean up main --- cmd/wonderwall/main.go | 107 ++--------------------------------------- pkg/config/redis.go | 19 ++++++++ pkg/crypto/crypter.go | 21 ++++++++ pkg/server/server.go | 54 +++++++++++++++++++++ pkg/session/session.go | 27 +++++++++++ 5 files changed, 126 insertions(+), 102 deletions(-) create mode 100644 pkg/server/server.go diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 3baa6c0..717847d 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -1,20 +1,9 @@ package main import ( - "context" - "crypto/tls" - "encoding/base64" "fmt" - "net/http" - "os" - "os/signal" - "syscall" - "time" - "github.com/go-chi/chi/v5" - "github.com/go-redis/redis/v8" "github.com/nais/liberator/pkg/conftools" - "github.com/nais/liberator/pkg/keygen" log "github.com/sirupsen/logrus" "github.com/nais/wonderwall/pkg/config" @@ -23,6 +12,7 @@ import ( "github.com/nais/wonderwall/pkg/metrics" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/router" + "github.com/nais/wonderwall/pkg/server" "github.com/nais/wonderwall/pkg/session" ) @@ -51,18 +41,9 @@ func run() error { log.Info(line) } - key, err := base64.StdEncoding.DecodeString(cfg.EncryptionKey) + key, err := crypto.EncryptionKeyOrGenerate(cfg) if err != nil { - if len(cfg.EncryptionKey) > 0 { - return fmt.Errorf("decode encryption key: %w", err) - } - } - - if len(key) == 0 { - key, err = keygen.Keygen(32) - if err != nil { - return fmt.Errorf("generate random encryption key: %w", err) - } + return err } prv, err := openid.NewProvider(cfg) @@ -71,7 +52,7 @@ func run() error { } crypt := crypto.NewCrypter(key) - sessionStore := setupSessionStore(cfg) + sessionStore := session.NewStore(cfg) httplogger := logging.NewHttpLogger(cfg) h, err := router.NewHandler(*cfg, crypt, httplogger, prv, sessionStore) if err != nil { @@ -86,85 +67,7 @@ func run() error { log.Fatalf("fatal: metrics server error: %s", err) } }() - return startServer(cfg, r) -} - -func setupSessionStore(cfg *config.Config) session.Store { - if len(cfg.Redis.Address) == 0 { - log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!") - return session.NewMemory() - } - - redisClient, err := configureRedisClient(cfg) - if err != nil { - log.Fatalf("Failed to configure Redis: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - - err = redisClient.Ping(ctx).Err() - if err != nil { - log.Warnf("Failed to connect to configured Redis, using cookie fallback: %v", err) - } - - log.Infof("Using Redis as session backing store") - return session.NewRedis(redisClient) -} - -func configureRedisClient(cfg *config.Config) (*redis.Client, error) { - opts := &redis.Options{ - Network: "tcp", - Addr: cfg.Redis.Address, - Username: cfg.Redis.Username, - Password: cfg.Redis.Password, - } - - if cfg.Redis.TLS { - opts.TLSConfig = &tls.Config{} - } - - redisClient := redis.NewClient(opts) - return redisClient, nil -} - -func startServer(cfg *config.Config, r chi.Router) error { - server := http.Server{ - Addr: cfg.BindAddress, - Handler: r, - } - - serverCtx, serverStopCtx := context.WithCancel(context.Background()) - - sig := make(chan os.Signal, 1) - signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - go func() { - <-sig - - shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, 20*time.Second) - - go func() { - <-shutdownCtx.Done() - if shutdownCtx.Err() == context.DeadlineExceeded { - log.Fatal("graceful shutdown timed out.. forcing exit.") - } - }() - - err := server.Shutdown(shutdownCtx) - if err != nil { - log.Fatal(err) - } - shutdownStopCtx() - serverStopCtx() - }() - - err := server.ListenAndServe() - if err != nil && err != http.ErrServerClosed { - return err - } - - <-serverCtx.Done() - return nil + return server.Start(cfg, r) } func main() { diff --git a/pkg/config/redis.go b/pkg/config/redis.go index 24301f0..d3e544d 100644 --- a/pkg/config/redis.go +++ b/pkg/config/redis.go @@ -1,6 +1,9 @@ package config import ( + "crypto/tls" + + "github.com/go-redis/redis/v8" flag "github.com/spf13/pflag" ) @@ -18,6 +21,22 @@ type Redis struct { TLS bool `json:"tls"` } +func (r *Redis) Client() (*redis.Client, error) { + opts := &redis.Options{ + Network: "tcp", + Addr: r.Address, + Username: r.Username, + Password: r.Password, + } + + if r.TLS { + opts.TLSConfig = &tls.Config{} + } + + redisClient := redis.NewClient(opts) + return redisClient, nil +} + func redisFlags() { flag.String(RedisAddress, "", "Address of Redis. An empty value will use in-memory session storage.") flag.String(RedisPassword, "", "Password for Redis.") diff --git a/pkg/crypto/crypter.go b/pkg/crypto/crypter.go index f964613..312ddff 100644 --- a/pkg/crypto/crypter.go +++ b/pkg/crypto/crypter.go @@ -4,11 +4,14 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "encoding/base64" "encoding/binary" "fmt" "time" "github.com/nais/liberator/pkg/keygen" + + "github.com/nais/wonderwall/pkg/config" ) type crypter struct { @@ -26,6 +29,24 @@ func NewCrypter(key []byte) Crypter { } } +func EncryptionKeyOrGenerate(cfg *config.Config) ([]byte, error) { + key, err := base64.StdEncoding.DecodeString(cfg.EncryptionKey) + if err != nil { + if len(cfg.EncryptionKey) > 0 { + return nil, fmt.Errorf("decode encryption key: %w", err) + } + } + + if len(key) == 0 { + key, err = keygen.Keygen(32) + if err != nil { + return nil, fmt.Errorf("generate random encryption key: %w", err) + } + } + + return key, nil +} + // Generate an initialization vector for encryption. // It consists of the current UNIX timestamp with nanoseconds, and four bytes of randomness. func IV() ([]byte, error) { diff --git a/pkg/server/server.go b/pkg/server/server.go new file mode 100644 index 0000000..0ad2553 --- /dev/null +++ b/pkg/server/server.go @@ -0,0 +1,54 @@ +package server + +import ( + "context" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/go-chi/chi/v5" + log "github.com/sirupsen/logrus" + + "github.com/nais/wonderwall/pkg/config" +) + +func Start(cfg *config.Config, r chi.Router) error { + server := http.Server{ + Addr: cfg.BindAddress, + Handler: r, + } + + serverCtx, serverStopCtx := context.WithCancel(context.Background()) + + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + go func() { + <-sig + + shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, 20*time.Second) + + go func() { + <-shutdownCtx.Done() + if shutdownCtx.Err() == context.DeadlineExceeded { + log.Fatal("graceful shutdown timed out.. forcing exit.") + } + }() + + err := server.Shutdown(shutdownCtx) + if err != nil { + log.Fatal(err) + } + shutdownStopCtx() + serverStopCtx() + }() + + err := server.ListenAndServe() + if err != nil && err != http.ErrServerClosed { + return err + } + + <-serverCtx.Done() + return nil +} diff --git a/pkg/session/session.go b/pkg/session/session.go index 6d4925d..0557010 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -7,6 +7,9 @@ import ( "encoding/json" "time" + log "github.com/sirupsen/logrus" + + "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/crypto" ) @@ -16,6 +19,30 @@ type Store interface { Delete(ctx context.Context, keys ...string) error } +func NewStore(cfg *config.Config) Store { + if len(cfg.Redis.Address) == 0 { + log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!") + return NewMemory() + } + + redisClient, err := cfg.Redis.Client() + if err != nil { + log.Fatalf("Failed to configure Redis: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + err = redisClient.Ping(ctx).Err() + if err != nil { + log.Warnf("Failed to connect to configured Redis, using cookie fallback: %v", err) + } else { + log.Infof("Using Redis as session backing store") + } + + return NewRedis(redisClient) +} + type EncryptedData struct { Data string `json:"data"` }