mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-16 21:36:46 +00:00
435 lines
16 KiB
Go
435 lines
16 KiB
Go
package config
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net/url"
|
|
"runtime/debug"
|
|
"time"
|
|
|
|
"github.com/mitchellh/mapstructure"
|
|
"github.com/nais/liberator/pkg/conftools"
|
|
"github.com/redis/go-redis/v9"
|
|
log "github.com/sirupsen/logrus"
|
|
flag "github.com/spf13/pflag"
|
|
"github.com/spf13/viper"
|
|
|
|
"github.com/nais/wonderwall/pkg/logging"
|
|
"github.com/nais/wonderwall/pkg/openid/acr"
|
|
)
|
|
|
|
type Config struct {
|
|
BindAddress string `json:"bind-address"`
|
|
LogFormat string `json:"log-format"`
|
|
LogLevel string `json:"log-level"`
|
|
MetricsBindAddress string `json:"metrics-bind-address"`
|
|
ShutdownGracefulPeriod time.Duration `json:"shutdown-graceful-period"`
|
|
ShutdownWaitBeforePeriod time.Duration `json:"shutdown-wait-before-period"`
|
|
Version string `json:"version"`
|
|
|
|
AutoLogin bool `json:"auto-login"`
|
|
AutoLoginIgnorePaths []string `json:"auto-login-ignore-paths"`
|
|
CookiePrefix string `json:"cookie-prefix"`
|
|
EncryptionKey string `json:"encryption-key"`
|
|
Ingresses []string `json:"ingress"`
|
|
UpstreamHost string `json:"upstream-host"`
|
|
UpstreamIP string `json:"upstream-ip"`
|
|
UpstreamPort int `json:"upstream-port"`
|
|
|
|
OpenTelemetry OpenTelemetry `json:"otel"`
|
|
OpenID OpenID `json:"openid"`
|
|
Redis Redis `json:"redis"`
|
|
Session Session `json:"session"`
|
|
SSO SSO `json:"sso"`
|
|
}
|
|
|
|
type OpenID struct {
|
|
ACRValues string `json:"acr-values"`
|
|
Audiences []string `json:"audiences"`
|
|
ClientID string `json:"client-id"`
|
|
ClientJWK string `json:"client-jwk"`
|
|
PostLogoutRedirectURI string `json:"post-logout-redirect-uri"`
|
|
Provider Provider `json:"provider"`
|
|
ResourceIndicator string `json:"resource-indicator"`
|
|
Scopes []string `json:"scopes"`
|
|
UILocales string `json:"ui-locales"`
|
|
WellKnownURL string `json:"well-known-url"`
|
|
}
|
|
|
|
func (in OpenID) TrustedAudiences() map[string]bool {
|
|
m := make(map[string]bool)
|
|
m[in.ClientID] = true
|
|
for _, aud := range in.Audiences {
|
|
m[aud] = true
|
|
}
|
|
|
|
return m
|
|
}
|
|
|
|
type OpenTelemetry struct {
|
|
Enabled bool `json:"enabled"`
|
|
ServiceName string `json:"service-name"`
|
|
}
|
|
|
|
type Redis struct {
|
|
Address string `json:"address"`
|
|
Username string `json:"username"`
|
|
Password string `json:"password"`
|
|
TLS bool `json:"tls"`
|
|
URI string `json:"uri"`
|
|
ConnectionIdleTimeout int `json:"connection-idle-timeout"`
|
|
}
|
|
|
|
func (r *Redis) Client() (*redis.Client, error) {
|
|
opts := &redis.Options{
|
|
Network: "tcp",
|
|
Addr: r.Address,
|
|
}
|
|
|
|
if r.TLS {
|
|
opts.TLSConfig = &tls.Config{}
|
|
}
|
|
|
|
if r.URI != "" {
|
|
var err error
|
|
|
|
opts, err = redis.ParseURL(r.URI)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
opts.MinIdleConns = 1
|
|
opts.MaxRetries = 5
|
|
|
|
if r.Username != "" {
|
|
opts.Username = r.Username
|
|
}
|
|
|
|
if r.Password != "" {
|
|
opts.Password = r.Password
|
|
}
|
|
|
|
if r.ConnectionIdleTimeout > 0 {
|
|
opts.ConnMaxIdleTime = time.Duration(r.ConnectionIdleTimeout) * time.Second
|
|
} else if r.ConnectionIdleTimeout == -1 {
|
|
opts.ConnMaxIdleTime = -1
|
|
}
|
|
|
|
return redis.NewClient(opts), nil
|
|
}
|
|
|
|
type Session struct {
|
|
Inactivity bool `json:"inactivity"`
|
|
InactivityTimeout time.Duration `json:"inactivity-timeout"`
|
|
MaxLifetime time.Duration `json:"max-lifetime"`
|
|
Refresh bool `json:"refresh"`
|
|
RefreshAuto bool `json:"refresh-auto"`
|
|
}
|
|
|
|
type SSO struct {
|
|
Enabled bool `json:"enabled"`
|
|
Domain string `json:"domain"`
|
|
Mode SSOMode `json:"mode"`
|
|
SessionCookieName string `json:"session-cookie-name"`
|
|
ServerURL string `json:"server-url"`
|
|
ServerDefaultRedirectURL string `json:"server-default-redirect-url"`
|
|
}
|
|
|
|
func (in SSO) IsServer() bool {
|
|
return in.Enabled && in.Mode == SSOModeServer
|
|
}
|
|
|
|
type Provider string
|
|
|
|
const (
|
|
ProviderAzure Provider = "azure"
|
|
ProviderIDPorten Provider = "idporten"
|
|
ProviderOpenID Provider = "openid"
|
|
)
|
|
|
|
type SSOMode string
|
|
|
|
const (
|
|
SSOModeServer SSOMode = "server"
|
|
SSOModeProxy SSOMode = "proxy"
|
|
)
|
|
|
|
const (
|
|
BindAddress = "bind-address"
|
|
LogFormat = "log-format"
|
|
LogLevel = "log-level"
|
|
MetricsBindAddress = "metrics-bind-address"
|
|
ShutdownGracefulPeriod = "shutdown-graceful-period"
|
|
ShutdownWaitBeforePeriod = "shutdown-wait-before-period"
|
|
|
|
AutoLogin = "auto-login"
|
|
AutoLoginIgnorePaths = "auto-login-ignore-paths"
|
|
CookiePrefix = "cookie-prefix"
|
|
EncryptionKey = "encryption-key"
|
|
Ingress = "ingress"
|
|
UpstreamHost = "upstream-host"
|
|
UpstreamIP = "upstream-ip"
|
|
UpstreamPort = "upstream-port"
|
|
|
|
OpenIDACRValues = "openid.acr-values"
|
|
OpenIDAudiences = "openid.audiences"
|
|
OpenIDClientID = "openid.client-id"
|
|
OpenIDClientJWK = "openid.client-jwk"
|
|
OpenIDPostLogoutRedirectURI = "openid.post-logout-redirect-uri"
|
|
OpenIDProvider = "openid.provider"
|
|
OpenIDResourceIndicator = "openid.resource-indicator"
|
|
OpenIDScopes = "openid.scopes"
|
|
OpenIDUILocales = "openid.ui-locales"
|
|
OpenIDWellKnownURL = "openid.well-known-url"
|
|
|
|
OpenTelemetryEnabled = "otel.enabled"
|
|
OpenTelemetryServiceName = "otel.service-name"
|
|
|
|
RedisAddress = "redis.address"
|
|
RedisPassword = "redis.password"
|
|
RedisTLS = "redis.tls"
|
|
RedisUsername = "redis.username"
|
|
RedisURI = "redis.uri"
|
|
RedisConnectionIdleTimeout = "redis.connection-idle-timeout"
|
|
|
|
SessionInactivity = "session.inactivity"
|
|
SessionInactivityTimeout = "session.inactivity-timeout"
|
|
SessionMaxLifetime = "session.max-lifetime"
|
|
SessionRefresh = "session.refresh"
|
|
SessionRefreshAuto = "session.refresh-auto"
|
|
|
|
SSOEnabled = "sso.enabled"
|
|
SSODomain = "sso.domain"
|
|
SSOModeFlag = "sso.mode"
|
|
SSOServerDefaultRedirectURL = "sso.server-default-redirect-url"
|
|
SSOSessionCookieName = "sso.session-cookie-name"
|
|
SSOServerURL = "sso.server-url"
|
|
)
|
|
|
|
func Initialize() (*Config, error) {
|
|
conftools.Initialize("wonderwall")
|
|
|
|
flag.String(BindAddress, "127.0.0.1:3000", "Listen address for public connections.")
|
|
flag.String(LogFormat, "json", "Log format, either 'json' or 'text'.")
|
|
flag.String(LogLevel, "info", "Logging verbosity level.")
|
|
flag.String(MetricsBindAddress, "127.0.0.1:3001", "Listen address for metrics only.")
|
|
flag.Duration(ShutdownGracefulPeriod, 30*time.Second, "Graceful shutdown period when receiving a shutdown signal after which the server is forcibly exited.")
|
|
flag.Duration(ShutdownWaitBeforePeriod, 0*time.Second, "Wait period when receiving a shutdown signal before actually starting a graceful shutdown. Useful for allowing propagation of Endpoint updates in Kubernetes.")
|
|
|
|
flag.Bool(AutoLogin, false, "Enforce authentication if the user does not have a valid session for all matching upstream paths. Automatically redirects HTTP navigation requests to login, otherwise responds with 401 with the Location header set.")
|
|
flag.StringSlice(AutoLoginIgnorePaths, []string{}, "Comma separated list of absolute paths to ignore when 'auto-login' is enabled. Supports basic wildcard matching with glob-style asterisks. Invalid patterns are ignored.")
|
|
flag.String(CookiePrefix, "io.nais.wonderwall", "Prefix for cookie names.")
|
|
flag.String(EncryptionKey, "", "Base64 encoded 256-bit cookie encryption key; must be identical in instances that share session store.")
|
|
flag.StringSlice(Ingress, []string{}, "Comma separated list of ingresses used to access the main application.")
|
|
flag.String(UpstreamHost, "127.0.0.1:8080", "Address of upstream host.")
|
|
flag.String(UpstreamIP, "", "IP of upstream host. Overrides 'upstream-host' if set.")
|
|
flag.Int(UpstreamPort, 0, "Port of upstream host. Overrides 'upstream-host' if set.")
|
|
|
|
flag.String(OpenIDACRValues, "", "Space separated string that configures the default security level (acr_values) parameter for authorization requests.")
|
|
flag.StringSlice(OpenIDAudiences, []string{}, "List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation.")
|
|
flag.String(OpenIDClientID, "", "Client ID for the OpenID client.")
|
|
flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format.")
|
|
flag.String(OpenIDPostLogoutRedirectURI, "", "URI for redirecting the user after successful logout at the Identity Provider.")
|
|
flag.String(OpenIDProvider, string(ProviderOpenID), "Provider configuration to load and use, either 'openid', 'azure', 'idporten'.")
|
|
flag.String(OpenIDResourceIndicator, "", "OAuth2 resource indicator to include in authorization request for acquiring audience-restricted tokens.")
|
|
flag.StringSlice(OpenIDScopes, []string{}, "List of additional scopes (other than 'openid') that should be used during the login flow.")
|
|
flag.String(OpenIDUILocales, "", "Space-separated string that configures the default UI locale (ui_locales) parameter for OAuth2 consent screen.")
|
|
flag.String(OpenIDWellKnownURL, "", "URI to the well-known OpenID Configuration metadata document.")
|
|
|
|
flag.Bool(OpenTelemetryEnabled, false, "Enable OpenTelemetry tracing.")
|
|
flag.String(OpenTelemetryServiceName, "wonderwall", "Service name to use for OpenTelemetry.")
|
|
|
|
flag.String(RedisURI, "", "Redis URI string. Prefer using this. An empty value will fall back to 'redis-address'.")
|
|
flag.String(RedisAddress, "", "Address of the Redis instance (host:port). An empty value will use in-memory session storage. Does not override address set by 'redis.uri'.")
|
|
flag.String(RedisPassword, "", "Password for Redis. Overrides password set by 'redis.uri'.")
|
|
flag.Bool(RedisTLS, true, "Whether or not to use TLS for connecting to Redis. Does not override TLS config set by 'redis.uri'.")
|
|
flag.String(RedisUsername, "", "Username for Redis. Overrides username set by 'redis.uri'.")
|
|
flag.Int(RedisConnectionIdleTimeout, 0, "Idle timeout for Redis connections, in seconds. If non-zero, the value should be less than the client timeout configured at the Redis server. A value of -1 disables timeout. If zero, the default value from go-redis is used (30 minutes). Overrides options set by 'redis.uri'.")
|
|
|
|
flag.Bool(SessionInactivity, false, "Automatically expire user sessions if they have not refreshed their tokens within a given duration.")
|
|
flag.Duration(SessionInactivityTimeout, 30*time.Minute, "Inactivity timeout for user sessions.")
|
|
flag.Duration(SessionMaxLifetime, 10*time.Hour, "Max lifetime for user sessions.")
|
|
flag.Bool(SessionRefresh, true, "Enable refresh tokens.")
|
|
flag.Bool(SessionRefreshAuto, true, "Enable automatic refresh of tokens. Only available in standalone mode. Will automatically refresh tokens if they are expired as long as the session is valid (i.e. not exceeding 'session.max-lifetime' or 'session.inactivity-timeout').")
|
|
|
|
flag.Bool(SSOEnabled, false, "Enable single sign-on mode; one server acting as the OIDC Relying Party, and N proxies. The proxies delegate most endpoint operations to the server, and only implements a reverse proxy that reads the user's session data from the shared store.")
|
|
flag.String(SSODomain, "", "The domain that the session cookies should be set for, usually the second-level domain name (e.g. example.com).")
|
|
flag.String(SSOModeFlag, string(SSOModeServer), "The SSO mode for this instance. Must be one of 'server' or 'proxy'.")
|
|
flag.String(SSOSessionCookieName, "", "Session cookie name. Must be the same across all SSO Servers and Proxies.")
|
|
flag.String(SSOServerDefaultRedirectURL, "", "The URL that the SSO server should redirect to by default if a given redirect query parameter is invalid.")
|
|
flag.String(SSOServerURL, "", "The URL used by the proxy to point to the SSO server instance.")
|
|
|
|
flag.Parse()
|
|
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
if err.(viper.ConfigFileNotFoundError) != err {
|
|
return nil, err
|
|
}
|
|
}
|
|
if err := viper.BindPFlags(flag.CommandLine); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch Provider(viper.GetString(OpenIDProvider)) {
|
|
case ProviderIDPorten:
|
|
viper.BindEnv(OpenIDClientID, "IDPORTEN_CLIENT_ID")
|
|
viper.BindEnv(OpenIDClientJWK, "IDPORTEN_CLIENT_JWK")
|
|
viper.BindEnv(OpenIDWellKnownURL, "IDPORTEN_WELL_KNOWN_URL")
|
|
|
|
viper.SetDefault(OpenIDACRValues, acr.IDPortenLevel4) // TODO - change to new value after migration
|
|
viper.SetDefault(OpenIDUILocales, "nb")
|
|
case ProviderAzure:
|
|
viper.BindEnv(OpenIDClientID, "AZURE_APP_CLIENT_ID")
|
|
viper.BindEnv(OpenIDClientJWK, "AZURE_APP_JWK")
|
|
viper.BindEnv(OpenIDWellKnownURL, "AZURE_APP_WELL_KNOWN_URL")
|
|
default:
|
|
viper.Set(OpenIDProvider, ProviderOpenID)
|
|
}
|
|
|
|
viper.Set("version", version())
|
|
|
|
cfg := new(Config)
|
|
err := viper.UnmarshalExact(cfg, func(dc *mapstructure.DecoderConfig) {
|
|
dc.TagName = "json"
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := logging.Setup(cfg.LogLevel, cfg.LogFormat); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
log.Tracef("Trace logging enabled")
|
|
|
|
maskedConfig := []string{
|
|
OpenIDClientJWK,
|
|
EncryptionKey,
|
|
RedisPassword,
|
|
}
|
|
|
|
for _, line := range conftools.Format(maskedConfig) {
|
|
log.WithField("logger", "wonderwall.config").Info(line)
|
|
}
|
|
|
|
err = cfg.Validate()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("validating config: %w", err)
|
|
}
|
|
|
|
cfg.upstreamHostOverride()
|
|
|
|
return cfg, nil
|
|
}
|
|
|
|
func (c *Config) Validate() error {
|
|
if c.Session.Inactivity && !c.Session.Refresh {
|
|
return fmt.Errorf("%q cannot be enabled without %q", SessionInactivity, SessionRefresh)
|
|
}
|
|
|
|
if c.Session.RefreshAuto && !c.Session.Refresh {
|
|
return fmt.Errorf("%q cannot be enabled without %q", SessionRefreshAuto, SessionRefresh)
|
|
}
|
|
|
|
if c.SSO.Enabled {
|
|
if len(c.Redis.Address) == 0 && len(c.Redis.URI) == 0 {
|
|
return fmt.Errorf("at least one of %q or %q must be set when %s is set", RedisAddress, RedisURI, SSOEnabled)
|
|
}
|
|
|
|
if len(c.SSO.SessionCookieName) == 0 {
|
|
return fmt.Errorf("%q must not be empty when %s is set", SSOSessionCookieName, SSOEnabled)
|
|
}
|
|
|
|
if c.Session.RefreshAuto {
|
|
return fmt.Errorf("%q cannot be enabled when %q is enabled", SessionRefreshAuto, SSOEnabled)
|
|
}
|
|
|
|
switch c.SSO.Mode {
|
|
case SSOModeProxy:
|
|
_, err := url.ParseRequestURI(c.SSO.ServerURL)
|
|
if err != nil {
|
|
return fmt.Errorf("%q must be a valid url: %w", SSOServerURL, err)
|
|
}
|
|
case SSOModeServer:
|
|
if len(c.SSO.Domain) == 0 {
|
|
return fmt.Errorf("%q cannot be empty", SSODomain)
|
|
}
|
|
|
|
_, err := url.ParseRequestURI(c.SSO.ServerDefaultRedirectURL)
|
|
if err != nil {
|
|
return fmt.Errorf("%q must be a valid url: %w", SSOServerDefaultRedirectURL, err)
|
|
}
|
|
default:
|
|
return fmt.Errorf("%q must be one of [%q, %q]", SSOModeFlag, SSOModeServer, SSOModeProxy)
|
|
}
|
|
}
|
|
|
|
if c.upstreamPortSet() {
|
|
if !c.upstreamIpSet() {
|
|
return fmt.Errorf("%q must be set when %q is set (was '%d')", UpstreamIP, UpstreamPort, c.UpstreamPort)
|
|
}
|
|
if !c.upstreamPortValid() {
|
|
return fmt.Errorf("%q must be in valid range (between '1' and '65535', was '%d')", UpstreamPort, c.UpstreamPort)
|
|
}
|
|
}
|
|
|
|
if c.upstreamIpSet() && !c.upstreamPortSet() {
|
|
return fmt.Errorf("%q must be set when %q is set (was %q)", UpstreamPort, UpstreamIP, c.UpstreamIP)
|
|
}
|
|
|
|
if c.ShutdownGracefulPeriod <= c.ShutdownWaitBeforePeriod {
|
|
return fmt.Errorf("%q must be greater than %q", ShutdownGracefulPeriod, ShutdownWaitBeforePeriod)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Config) upstreamIpSet() bool {
|
|
return c.UpstreamIP != ""
|
|
}
|
|
|
|
func (c *Config) upstreamPortSet() bool {
|
|
return c.UpstreamPort != 0
|
|
}
|
|
|
|
func (c *Config) upstreamPortValid() bool {
|
|
return c.UpstreamPort >= 1 && c.UpstreamPort <= 65535
|
|
}
|
|
|
|
func (c *Config) upstreamHostOverride() {
|
|
if c.upstreamIpSet() && c.upstreamPortSet() && c.upstreamPortValid() {
|
|
override := fmt.Sprintf("%s:%d", c.UpstreamIP, c.UpstreamPort)
|
|
|
|
log.WithField("logger", "wonderwall.config").
|
|
Infof("%q and %q were set; overriding %q from %q to %q", UpstreamHost, UpstreamPort, UpstreamHost, c.UpstreamHost, override)
|
|
|
|
c.UpstreamHost = override
|
|
}
|
|
}
|
|
|
|
func version() string {
|
|
info, ok := debug.ReadBuildInfo()
|
|
if !ok {
|
|
return ""
|
|
}
|
|
|
|
var rev string
|
|
var last time.Time
|
|
|
|
for _, kv := range info.Settings {
|
|
switch kv.Key {
|
|
case "vcs.revision":
|
|
rev = kv.Value
|
|
case "vcs.time":
|
|
last, _ = time.Parse(time.RFC3339, kv.Value)
|
|
}
|
|
}
|
|
|
|
if len(rev) > 7 {
|
|
rev = rev[:7]
|
|
}
|
|
|
|
return fmt.Sprintf("%s-%s", last.Format("2006-01-02-150405"), rev)
|
|
}
|