diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index e799248..4c7e49d 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -58,7 +58,7 @@ func run() error { crypt := crypto.NewCrypter(key) sessionStore := session.NewStore(cfg) httplogger := logging.NewHttpLogger(cfg) - h, err := router.NewHandler(jwksRefreshCtx, cfg, crypt, httplogger, openidConfig, sessionStore) + h, err := router.NewHandler(jwksRefreshCtx, openidConfig, crypt, httplogger, sessionStore) if err != nil { return fmt.Errorf("initializing routing handler: %w", err) } diff --git a/pkg/mock/config.go b/pkg/mock/config.go index 324defb..d0594bc 100644 --- a/pkg/mock/config.go +++ b/pkg/mock/config.go @@ -20,10 +20,9 @@ func Config() *config.Config { } type Configuration struct { - ClientConfig *TestClientConfiguration - ProviderConfig *openidconfig.Provider - IngressConfig string - LoginstatusConfig config.Loginstatus + ClientConfig *TestClientConfiguration + ProviderConfig *openidconfig.Provider + WonderwallConfig *config.Config } func (c Configuration) Client() openidconfig.Client { @@ -34,19 +33,18 @@ func (c Configuration) Provider() *openidconfig.Provider { return c.ProviderConfig } -func (c Configuration) Ingress() string { - return c.IngressConfig +func (c Configuration) ProviderName() string { + return string(c.WonderwallConfig.OpenID.Provider) } -func (c Configuration) Loginstatus() config.Loginstatus { - return c.LoginstatusConfig +func (c Configuration) Wonderwall() *config.Config { + return c.WonderwallConfig } func NewTestConfiguration(cfg *config.Config) Configuration { return Configuration{ - ClientConfig: clientConfiguration(cfg), - ProviderConfig: providerConfiguration(), - IngressConfig: cfg.Ingress, - LoginstatusConfig: cfg.Loginstatus, + ClientConfig: clientConfiguration(cfg), + ProviderConfig: providerConfiguration(), + WonderwallConfig: cfg, } } diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index cd1e4bd..9244291 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -72,7 +72,7 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider { sessionStore := session.NewMemory() ctx, cancel := context.WithCancel(context.Background()) - rpHandler, err := router.NewHandler(ctx, cfg, crypter, zerolog.Nop(), openidConfig, sessionStore) + rpHandler, err := router.NewHandler(ctx, openidConfig, crypter, zerolog.Nop(), sessionStore) if err != nil { panic(err) } diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index 548df9f..5ce5421 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -53,7 +53,7 @@ func NewLogin(c Client, r *http.Request) (Login, error) { return nil, fmt.Errorf("generating login url: %w", err) } - redirect := request.CanonicalRedirectURL(r, c.config().Ingress()) + redirect := request.CanonicalRedirectURL(r, c.config().Wonderwall().Ingress) cookie := params.cookie(redirect) return login{ @@ -142,8 +142,8 @@ func (in *loginParameters) authCodeURL(r *http.Request) (string, error) { oauth2.SetAuthURLParam("code_challenge_method", "S256"), } - if in.config().Loginstatus().NeedsResourceIndicator() { - opts = append(opts, oauth2.SetAuthURLParam("resource", in.config().Loginstatus().ResourceIndicator)) + if in.config().Wonderwall().Loginstatus.NeedsResourceIndicator() { + opts = append(opts, oauth2.SetAuthURLParam("resource", in.config().Wonderwall().Loginstatus.ResourceIndicator)) } opts, err := in.withSecurityLevel(r, opts) diff --git a/pkg/openid/config/config.go b/pkg/openid/config/config.go index 078b9ec..20aee1c 100644 --- a/pkg/openid/config/config.go +++ b/pkg/openid/config/config.go @@ -7,16 +7,15 @@ import ( type Config interface { Client() Client Provider() *Provider + ProviderName() string - Ingress() string - Loginstatus() wonderwallconfig.Loginstatus + Wonderwall() *wonderwallconfig.Config } type config struct { + cfg *wonderwallconfig.Config clientConfig Client providerConfig *Provider - ingress string - loginstatus wonderwallconfig.Loginstatus } func (c config) Client() Client { @@ -27,12 +26,12 @@ func (c config) Provider() *Provider { return c.providerConfig } -func (c config) Ingress() string { - return c.ingress +func (c config) ProviderName() string { + return string(c.cfg.OpenID.Provider) } -func (c config) Loginstatus() wonderwallconfig.Loginstatus { - return c.loginstatus +func (c config) Wonderwall() *wonderwallconfig.Config { + return c.cfg } func NewConfig(cfg *wonderwallconfig.Config) (Config, error) { @@ -47,9 +46,8 @@ func NewConfig(cfg *wonderwallconfig.Config) (Config, error) { } return config{ + cfg: cfg, clientConfig: clientCfg, providerConfig: providerCfg, - ingress: cfg.Ingress, - loginstatus: cfg.Loginstatus, }, nil } diff --git a/pkg/router/handler.go b/pkg/router/handler.go index 4dbf70b..6ea16ac 100644 --- a/pkg/router/handler.go +++ b/pkg/router/handler.go @@ -17,12 +17,11 @@ import ( ) type Handler struct { + Cfg openidconfig.Config Client client.Client - Config *config.Config CookieOptions cookie.Options Crypter crypto.Crypter Loginstatus loginstatus.Client - OpenIDConfig openidconfig.Config Provider provider.Provider Sessions session.Store Httplogger zerolog.Logger @@ -30,35 +29,33 @@ type Handler struct { func NewHandler( jwksRefreshCtx context.Context, - cfg *config.Config, + cfg openidconfig.Config, crypter crypto.Crypter, httplogger zerolog.Logger, - openidConfig openidconfig.Config, sessionStore session.Store, ) (*Handler, error) { - loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient) + loginstatusClient := loginstatus.NewClient(cfg.Wonderwall().Loginstatus, http.DefaultClient) - cookiePath := config.ParseIngress(cfg.Ingress) + cookiePath := config.ParseIngress(cfg.Wonderwall().Ingress) cookieOpts := cookie.DefaultOptions().WithPath(cookiePath) - openidProvider, err := provider.NewProvider(jwksRefreshCtx, openidConfig) + openidProvider, err := provider.NewProvider(jwksRefreshCtx, cfg) if err != nil { return nil, err } - openidClient := client.NewClient(openidConfig) + openidClient := client.NewClient(cfg) if err != nil { return nil, err } return &Handler{ Client: openidClient, - Config: cfg, CookieOptions: cookieOpts, Crypter: crypter, Httplogger: httplogger, Loginstatus: loginstatusClient, - OpenIDConfig: openidConfig, + Cfg: cfg, Provider: openidProvider, Sessions: sessionStore, }, nil diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index f8d4aa3..765a855 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -71,7 +71,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - err = tokens.IDToken.Validate(h.OpenIDConfig, loginCookie.Nonce) + err = tokens.IDToken.Validate(h.Cfg, loginCookie.Nonce) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: validating id_token: %w", err)) return @@ -83,7 +83,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - if h.Config.Loginstatus.Enabled { + if h.Cfg.Wonderwall().Loginstatus.Enabled { tokenResponse, err := h.getLoginstatusToken(r.Context(), tokens) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: exchanging loginstatus token: %w", err)) diff --git a/pkg/router/handler_default.go b/pkg/router/handler_default.go index 7504765..bf47a7b 100644 --- a/pkg/router/handler_default.go +++ b/pkg/router/handler_default.go @@ -22,20 +22,20 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { isAuthenticated = true // force new authentication if loginstatus is enabled and cookie isn't set - if h.Config.Loginstatus.Enabled && !h.Loginstatus.HasCookie(r) { + if h.Cfg.Wonderwall().Loginstatus.Enabled && !h.Loginstatus.HasCookie(r) { isAuthenticated = false log.Info("default: loginstatus was enabled, but no matching cookie was found; state is now unauthenticated") } } - if !isAuthenticated && h.Config.AutoLogin { + if !isAuthenticated && h.Cfg.Wonderwall().AutoLogin { r.Header.Add("Referer", r.URL.String()) h.Login(w, r) return } director := func(upstreamRequest *http.Request) { - modifyRequest(upstreamRequest, r, h.Config.UpstreamHost) + modifyRequest(upstreamRequest, r, h.Cfg.Wonderwall().UpstreamHost) if isAuthenticated { withAuthentication(upstreamRequest, sessionData) diff --git a/pkg/router/handler_error.go b/pkg/router/handler_error.go index 8a122f2..1425b51 100644 --- a/pkg/router/handler_error.go +++ b/pkg/router/handler_error.go @@ -43,7 +43,7 @@ func (h *Handler) respondError(w http.ResponseWriter, r *http.Request, statusCod logger := logentry.LogEntry(r.Context()) logger.WithLevel(level).Stack().Err(cause).Msgf("error in route: %+v", cause) - if len(h.Config.ErrorRedirectURI) > 0 { + if len(h.Cfg.Wonderwall().ErrorRedirectURI) > 0 { err := h.customErrorRedirect(w, r, statusCode) if err == nil { return @@ -63,7 +63,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s errorPage := ErrorPage{ CorrelationID: middleware.GetReqID(r.Context()), - RetryURI: RetryURI(r, h.Config.Ingress, loginCookie), + RetryURI: RetryURI(r, h.Cfg.Wonderwall().Ingress, loginCookie), } err = errorTemplate.Execute(w, errorPage) if err != nil { @@ -72,7 +72,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s } func (h *Handler) customErrorRedirect(w http.ResponseWriter, r *http.Request, statusCode int) error { - override, err := url.Parse(h.Config.ErrorRedirectURI) + override, err := url.Parse(h.Cfg.Wonderwall().ErrorRedirectURI) if err != nil { return err } diff --git a/pkg/router/handler_frontchannellogout.go b/pkg/router/handler_frontchannellogout.go index ee286ed..e663a4b 100644 --- a/pkg/router/handler_frontchannellogout.go +++ b/pkg/router/handler_frontchannellogout.go @@ -16,7 +16,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { // Unconditionally destroy all local references to the session. cookie.Clear(w, cookie.Session, h.CookieOptions) - if h.Config.Loginstatus.Enabled { + if h.Cfg.Wonderwall().Loginstatus.Enabled { h.Loginstatus.ClearCookie(w, h.CookieOptions) } diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index 71653d9..ed37a9c 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -37,11 +37,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { cookie.Clear(w, cookie.Session, h.CookieOptions) - if h.Config.Loginstatus.Enabled { + if h.Cfg.Wonderwall().Loginstatus.Enabled { h.Loginstatus.ClearCookie(w, h.CookieOptions) } - u, err := url.Parse(h.OpenIDConfig.Provider().EndSessionEndpoint) + u, err := url.Parse(h.Cfg.Provider().EndSessionEndpoint) if err != nil { h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err)) return @@ -60,7 +60,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { } v := u.Query() - v.Add("post_logout_redirect_uri", h.OpenIDConfig.Client().GetLogoutCallbackURI()) + v.Add("post_logout_redirect_uri", h.Cfg.Client().GetLogoutCallbackURI()) v.Add("state", logoutCookie.State) if len(idToken) > 0 { @@ -86,7 +86,7 @@ func (h *Handler) logoutCookie() (*openid.LogoutCookie, error) { return &openid.LogoutCookie{ State: state, - RedirectTo: h.OpenIDConfig.Client().GetPostLogoutRedirectURI(), + RedirectTo: h.Cfg.Client().GetPostLogoutRedirectURI(), }, nil } diff --git a/pkg/router/handler_logout_callback.go b/pkg/router/handler_logout_callback.go index bd3bb40..275d69f 100644 --- a/pkg/router/handler_logout_callback.go +++ b/pkg/router/handler_logout_callback.go @@ -24,7 +24,7 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) { logoutCookie, err := h.getLogoutCookie(r) if err != nil { logger.Warn().Msgf("logout/callback: getting cookie: %+v", err) - http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect) + http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect) return } @@ -34,13 +34,13 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) { if expectedState != actualState { logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s; falling back to ingress", expectedState, actualState) - http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect) + http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect) return } if len(logoutCookie.RedirectTo) == 0 { logger.Warn().Msgf("logout/callback: empty redirect; falling back to ingress") - http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect) + http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect) return } diff --git a/pkg/router/router.go b/pkg/router/router.go index 6f009d6..4e4cc5b 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -13,9 +13,9 @@ func New(handler *Handler) chi.Router { r := chi.NewRouter() r.Use(middleware.CorrelationIDHandler) r.Use(chi_middleware.Recoverer) - prometheusMiddleware := middleware.NewPrometheusMiddleware("wonderwall", string(handler.Config.OpenID.Provider)) + prometheusMiddleware := middleware.NewPrometheusMiddleware("wonderwall", handler.Cfg.ProviderName()) - prefix := config.ParseIngress(handler.Config.Ingress) + prefix := config.ParseIngress(handler.Cfg.Wonderwall().Ingress) r.Route(prefix+paths.OAuth2, func(r chi.Router) { r.Use(middleware.LogEntryHandler(handler.Httplogger)) diff --git a/pkg/router/session.go b/pkg/router/session.go index c573342..b0c232e 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -23,7 +23,7 @@ import ( // Thus, we cannot assume that the value of `sid` or `session_state` to uniquely identify the pair of (user, application session) // if using a shared session store. func (h *Handler) localSessionID(sessionID string) string { - return fmt.Sprintf("%s:%s:%s", h.Config.OpenID.Provider, h.OpenIDConfig.Client().GetClientID(), sessionID) + return fmt.Sprintf("%s:%s:%s", h.Cfg.ProviderName(), h.Cfg.Client().GetClientID(), sessionID) } func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) { @@ -67,7 +67,7 @@ func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Da } func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration { - defaultSessionLifetime := h.Config.SessionMaxLifetime + defaultSessionLifetime := h.Cfg.Wonderwall().SessionMaxLifetime tokenDuration := tokenExpiry.Sub(time.Now()) @@ -79,7 +79,7 @@ func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration { } func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *jwt.Tokens, rawTokens *oauth2.Token, params url.Values) error { - externalSessionID, err := session.NewSessionID(h.OpenIDConfig.Provider(), tokens.IDToken, params) + externalSessionID, err := session.NewSessionID(h.Cfg.Provider(), tokens.IDToken, params) if err != nil { return fmt.Errorf("generating session ID: %w", err) }