mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-06 08:27:10 +00:00
refactor: clean up main
This commit is contained in:
@@ -1,20 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/nais/liberator/pkg/conftools"
|
||||
"github.com/nais/liberator/pkg/keygen"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
@@ -23,6 +12,7 @@ import (
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/router"
|
||||
"github.com/nais/wonderwall/pkg/server"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
@@ -51,18 +41,9 @@ func run() error {
|
||||
log.Info(line)
|
||||
}
|
||||
|
||||
key, err := base64.StdEncoding.DecodeString(cfg.EncryptionKey)
|
||||
key, err := crypto.EncryptionKeyOrGenerate(cfg)
|
||||
if err != nil {
|
||||
if len(cfg.EncryptionKey) > 0 {
|
||||
return fmt.Errorf("decode encryption key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(key) == 0 {
|
||||
key, err = keygen.Keygen(32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate random encryption key: %w", err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
prv, err := openid.NewProvider(cfg)
|
||||
@@ -71,7 +52,7 @@ func run() error {
|
||||
}
|
||||
|
||||
crypt := crypto.NewCrypter(key)
|
||||
sessionStore := setupSessionStore(cfg)
|
||||
sessionStore := session.NewStore(cfg)
|
||||
httplogger := logging.NewHttpLogger(cfg)
|
||||
h, err := router.NewHandler(*cfg, crypt, httplogger, prv, sessionStore)
|
||||
if err != nil {
|
||||
@@ -86,85 +67,7 @@ func run() error {
|
||||
log.Fatalf("fatal: metrics server error: %s", err)
|
||||
}
|
||||
}()
|
||||
return startServer(cfg, r)
|
||||
}
|
||||
|
||||
func setupSessionStore(cfg *config.Config) session.Store {
|
||||
if len(cfg.Redis.Address) == 0 {
|
||||
log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!")
|
||||
return session.NewMemory()
|
||||
}
|
||||
|
||||
redisClient, err := configureRedisClient(cfg)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to configure Redis: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
|
||||
err = redisClient.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to connect to configured Redis, using cookie fallback: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("Using Redis as session backing store")
|
||||
return session.NewRedis(redisClient)
|
||||
}
|
||||
|
||||
func configureRedisClient(cfg *config.Config) (*redis.Client, error) {
|
||||
opts := &redis.Options{
|
||||
Network: "tcp",
|
||||
Addr: cfg.Redis.Address,
|
||||
Username: cfg.Redis.Username,
|
||||
Password: cfg.Redis.Password,
|
||||
}
|
||||
|
||||
if cfg.Redis.TLS {
|
||||
opts.TLSConfig = &tls.Config{}
|
||||
}
|
||||
|
||||
redisClient := redis.NewClient(opts)
|
||||
return redisClient, nil
|
||||
}
|
||||
|
||||
func startServer(cfg *config.Config, r chi.Router) error {
|
||||
server := http.Server{
|
||||
Addr: cfg.BindAddress,
|
||||
Handler: r,
|
||||
}
|
||||
|
||||
serverCtx, serverStopCtx := context.WithCancel(context.Background())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
go func() {
|
||||
<-sig
|
||||
|
||||
shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, 20*time.Second)
|
||||
|
||||
go func() {
|
||||
<-shutdownCtx.Done()
|
||||
if shutdownCtx.Err() == context.DeadlineExceeded {
|
||||
log.Fatal("graceful shutdown timed out.. forcing exit.")
|
||||
}
|
||||
}()
|
||||
|
||||
err := server.Shutdown(shutdownCtx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
shutdownStopCtx()
|
||||
serverStopCtx()
|
||||
}()
|
||||
|
||||
err := server.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
|
||||
<-serverCtx.Done()
|
||||
return nil
|
||||
return server.Start(cfg, r)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
flag "github.com/spf13/pflag"
|
||||
)
|
||||
|
||||
@@ -18,6 +21,22 @@ type Redis struct {
|
||||
TLS bool `json:"tls"`
|
||||
}
|
||||
|
||||
func (r *Redis) Client() (*redis.Client, error) {
|
||||
opts := &redis.Options{
|
||||
Network: "tcp",
|
||||
Addr: r.Address,
|
||||
Username: r.Username,
|
||||
Password: r.Password,
|
||||
}
|
||||
|
||||
if r.TLS {
|
||||
opts.TLSConfig = &tls.Config{}
|
||||
}
|
||||
|
||||
redisClient := redis.NewClient(opts)
|
||||
return redisClient, nil
|
||||
}
|
||||
|
||||
func redisFlags() {
|
||||
flag.String(RedisAddress, "", "Address of Redis. An empty value will use in-memory session storage.")
|
||||
flag.String(RedisPassword, "", "Password for Redis.")
|
||||
|
||||
@@ -4,11 +4,14 @@ import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/nais/liberator/pkg/keygen"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
)
|
||||
|
||||
type crypter struct {
|
||||
@@ -26,6 +29,24 @@ func NewCrypter(key []byte) Crypter {
|
||||
}
|
||||
}
|
||||
|
||||
func EncryptionKeyOrGenerate(cfg *config.Config) ([]byte, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(cfg.EncryptionKey)
|
||||
if err != nil {
|
||||
if len(cfg.EncryptionKey) > 0 {
|
||||
return nil, fmt.Errorf("decode encryption key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(key) == 0 {
|
||||
key, err = keygen.Keygen(32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate random encryption key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// Generate an initialization vector for encryption.
|
||||
// It consists of the current UNIX timestamp with nanoseconds, and four bytes of randomness.
|
||||
func IV() ([]byte, error) {
|
||||
|
||||
54
pkg/server/server.go
Normal file
54
pkg/server/server.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
)
|
||||
|
||||
func Start(cfg *config.Config, r chi.Router) error {
|
||||
server := http.Server{
|
||||
Addr: cfg.BindAddress,
|
||||
Handler: r,
|
||||
}
|
||||
|
||||
serverCtx, serverStopCtx := context.WithCancel(context.Background())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
go func() {
|
||||
<-sig
|
||||
|
||||
shutdownCtx, shutdownStopCtx := context.WithTimeout(serverCtx, 20*time.Second)
|
||||
|
||||
go func() {
|
||||
<-shutdownCtx.Done()
|
||||
if shutdownCtx.Err() == context.DeadlineExceeded {
|
||||
log.Fatal("graceful shutdown timed out.. forcing exit.")
|
||||
}
|
||||
}()
|
||||
|
||||
err := server.Shutdown(shutdownCtx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
shutdownStopCtx()
|
||||
serverStopCtx()
|
||||
}()
|
||||
|
||||
err := server.ListenAndServe()
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return err
|
||||
}
|
||||
|
||||
<-serverCtx.Done()
|
||||
return nil
|
||||
}
|
||||
@@ -7,6 +7,9 @@ import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/crypto"
|
||||
)
|
||||
|
||||
@@ -16,6 +19,30 @@ type Store interface {
|
||||
Delete(ctx context.Context, keys ...string) error
|
||||
}
|
||||
|
||||
func NewStore(cfg *config.Config) Store {
|
||||
if len(cfg.Redis.Address) == 0 {
|
||||
log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!")
|
||||
return NewMemory()
|
||||
}
|
||||
|
||||
redisClient, err := cfg.Redis.Client()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to configure Redis: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
|
||||
err = redisClient.Ping(ctx).Err()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to connect to configured Redis, using cookie fallback: %v", err)
|
||||
} else {
|
||||
log.Infof("Using Redis as session backing store")
|
||||
}
|
||||
|
||||
return NewRedis(redisClient)
|
||||
}
|
||||
|
||||
type EncryptedData struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user