diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 6a37f36..9b4bcf6 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -100,9 +100,7 @@ func run() error { return fmt.Errorf("initializing routing handler: %w", err) } - prefixes := config.ParseIngresses(cfg.Ingresses) - - r := router.New(handler, prefixes) + r := router.New(handler) go func() { err := metrics.Handle(cfg.MetricsBindAddress) diff --git a/pkg/config/config.go b/pkg/config/config.go index 48f15df..af3c01e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,11 +1,13 @@ package config import ( + "time" + "github.com/nais/liberator/pkg/conftools" - "github.com/nais/wonderwall/pkg/token" flag "github.com/spf13/pflag" "github.com/spf13/viper" - "time" + + "github.com/nais/wonderwall/pkg/token" ) type Config struct { @@ -17,7 +19,7 @@ type Config struct { LogFormat string `json:"log-format"` LogLevel string `json:"log-level"` Redis string `json:"redis"` - Ingresses []string `json:"ingresses"` + Ingress string `json:"ingress"` ErrorRedirectURI string `json:"error-redirect-uri"` } @@ -52,7 +54,8 @@ const ( LogLevel = "log-level" EncryptionKey = "encryption-key" Redis = "redis" - Ingresses = "ingresses" + Ingress = "ingress" + ErrorRedirectURI = "error-redirect-uri" IDPortenClientID = "idporten.client-id" IDPortenClientJWK = "idporten.client-jwk" IDPortenRedirectURI = "idporten.redirect-uri" @@ -64,7 +67,6 @@ const ( IDPortenPostLogoutRedirectURI = "idporten.post-logout-redirect-uri" IDPortenScopes = "idporten.scopes" IDPortenSessionMaxLifetime = "idporten.session-max-lifetime" - ErrorRedirectURI = "error-redirect-uri" ) func bindNAIS() { @@ -85,6 +87,9 @@ func Initialize() *Config { flag.String(UpstreamHost, "127.0.0.1:8080", "Address of upstream host.") flag.String(EncryptionKey, "", "Base64 encoded 256-bit cookie encryption key; must be identical in instances that share session store.") flag.String(Redis, "", "Address of Redis. An empty value will use in-memory session storage.") + flag.String(Ingress, "/", "Ingress used to access the main application.") + flag.String(ErrorRedirectURI, "", "URI to redirect user to on errors for custom error handling.") + flag.Bool(IDPortenSecurityLevelEnabled, true, "Toggle for setting the sceurity level (acr_values) parameter for authorization requests.") flag.String(IDPortenSecurityLevelValue, "Level4", "Requested security level, either Level3 or Level4.") flag.Bool(IDPortenLocaleEnabled, true, "Toggle for setting the locale parameter for authorization requests.") @@ -92,8 +97,6 @@ func Initialize() *Config { flag.String(IDPortenPostLogoutRedirectURI, "https://www.nav.no", "URI for redirecting the user after successful logout at IDPorten.") flag.StringSlice(IDPortenScopes, []string{token.ScopeOpenID}, "List of scopes that should be used during the Auth Code flow.") flag.Duration(IDPortenSessionMaxLifetime, time.Hour, "Max lifetime for user sessions.") - flag.StringSlice(Ingresses, []string{"/"}, "Ingresses used to access the main application.") - flag.String(ErrorRedirectURI, "", "URI to redirect user to on errors for custom error handling.") return &Config{} } diff --git a/pkg/config/ingress.go b/pkg/config/ingress.go index a0767db..28503cf 100644 --- a/pkg/config/ingress.go +++ b/pkg/config/ingress.go @@ -5,24 +5,13 @@ import ( "strings" ) -func ParseIngresses(ingresses []string) []string { - prefixMap := make(map[string]interface{}) - - for _, ingress := range ingresses { - ingressURL, err := url.Parse(ingress) - if err != nil { - continue - } - path := ingressURL.Path - path = strings.TrimRight(path, "/") - - prefixMap[path] = new(interface{}) +func ParseIngress(ingress string) string { + ingressURL, err := url.Parse(ingress) + if err != nil { + return "" } + path := ingressURL.Path + path = strings.TrimRight(path, "/") - prefixes := make([]string, 0) - for prefix := range prefixMap { - prefixes = append(prefixes, prefix) - } - - return prefixes + return path } diff --git a/pkg/config/ingress_test.go b/pkg/config/ingress_test.go index 69d3a7c..940ed08 100644 --- a/pkg/config/ingress_test.go +++ b/pkg/config/ingress_test.go @@ -1,7 +1,6 @@ package config_test import ( - "sort" "testing" "github.com/stretchr/testify/assert" @@ -9,27 +8,52 @@ import ( "github.com/nais/wonderwall/pkg/config" ) -func TestParseIngresses(t *testing.T) { - ingresses := []string{"https://tjenester.nav.no/sykepenger/", "https://sykepenger.nav.no/", "https://sykepenger-test.nav.no"} - expected := []string{"", "/sykepenger"} +func TestParseIngress(t *testing.T) { + for _, test := range []struct{ + ingress string + want string + }{ + { + ingress: "https://tjenester.nav.no/sykepenger/", + want: "/sykepenger", + }, + { + ingress: "https://tjenester.nav.no/sykepenger/test", + want: "/sykepenger/test", + }, + { + ingress: "https://tjenester.nav.no/test/sykepenger/", + want: "/test/sykepenger", + }, + { + ingress: "https://sykepenger.nav.no/", + want: "", + }, + { + ingress: "https://sykepenger-test.nav.no", + want: "", + }, - prefixes := config.ParseIngresses(ingresses) - sort.Strings(prefixes) - assert.Equal(t, expected, prefixes) + } { + t.Run(test.ingress, func(t *testing.T) { + prefix := config.ParseIngress(test.ingress) + assert.Equal(t, test.want, prefix) + }) + } } func TestParseEmptyIngress(t *testing.T) { - ingresses := []string{""} - expected := []string{""} + ingress := "" + expected := "" - prefixes := config.ParseIngresses(ingresses) - assert.Equal(t, expected, prefixes) + prefix := config.ParseIngress(ingress) + assert.Equal(t, expected, prefix) } func TestParseDefaultIngress(t *testing.T) { - ingresses := []string{"/"} - expected := []string{""} + ingress := "/" + expected := "" - prefixes := config.ParseIngresses(ingresses) - assert.Equal(t, expected, prefixes) + prefix := config.ParseIngress(ingress) + assert.Equal(t, expected, prefix) } diff --git a/pkg/router/router.go b/pkg/router/router.go index 2de960b..9903364 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -4,26 +4,27 @@ import ( "github.com/go-chi/chi/v5" chi_middleware "github.com/go-chi/chi/v5/middleware" + "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/middleware" ) -func New(handler *Handler, prefixes []string) chi.Router { +func New(handler *Handler) chi.Router { r := chi.NewRouter() r.Use(middleware.CorrelationIDHandler) r.Use(middleware.LogEntryHandler(handler.httplogger)) r.Use(chi_middleware.Recoverer) prometheusMiddleware := middleware.NewPrometheusMiddleware("wonderwall") - for _, prefix := range prefixes { - r.Route(prefix+"/oauth2", func(r chi.Router) { - r.Use(prometheusMiddleware.Handler) - r.Use(chi_middleware.NoCache) - r.Get("/login", handler.Login) - r.Get("/callback", handler.Callback) - r.Get("/logout", handler.Logout) - r.Get("/logout/frontchannel", handler.FrontChannelLogout) - }) - } + prefix := config.ParseIngress(handler.Config.Ingress) + + r.Route(prefix+"/oauth2", func(r chi.Router) { + r.Use(prometheusMiddleware.Handler) + r.Use(chi_middleware.NoCache) + r.Get("/login", handler.Login) + r.Get("/callback", handler.Callback) + r.Get("/logout", handler.Logout) + r.Get("/logout/frontchannel", handler.FrontChannelLogout) + }) r.HandleFunc("/*", handler.Default) return r } diff --git a/pkg/router/router_test.go b/pkg/router/router_test.go index d964440..c7f0b50 100644 --- a/pkg/router/router_test.go +++ b/pkg/router/router_test.go @@ -100,8 +100,7 @@ func TestHandler_Login(t *testing.T) { cfg := defaultConfig() h := handler(cfg) - prefixes := config.ParseIngresses([]string{""}) - r := router.New(h, prefixes) + r := router.New(h) jar, err := cookiejar.New(nil) assert.NoError(t, err) @@ -166,8 +165,7 @@ func TestHandler_Callback_and_Logout(t *testing.T) { cfg.IDPorten.WellKnown.EndSessionEndpoint = idpserver.URL + "/endsession" h := handler(cfg) - prefixes := config.ParseIngresses([]string{""}) - r := router.New(h, prefixes) + r := router.New(h) server := httptest.NewServer(r) h.Config.IDPorten.RedirectURI = server.URL + "/oauth2/callback" @@ -262,8 +260,7 @@ func TestHandler_FrontChannelLogout(t *testing.T) { cfg.IDPorten.WellKnown.TokenEndpoint = idpserver.URL + "/token" h := handler(cfg) - prefixes := config.ParseIngresses([]string{""}) - r := router.New(h, prefixes) + r := router.New(h) server := httptest.NewServer(r) h.Config.IDPorten.RedirectURI = server.URL + "/oauth2/callback"