refactor(handler): deduplicate configuration

This commit is contained in:
Trong Huu Nguyen
2022-07-05 14:43:40 +02:00
parent a4c3e72fc9
commit 42938ee8b3
14 changed files with 51 additions and 58 deletions

View File

@@ -58,7 +58,7 @@ func run() error {
crypt := crypto.NewCrypter(key)
sessionStore := session.NewStore(cfg)
httplogger := logging.NewHttpLogger(cfg)
h, err := router.NewHandler(jwksRefreshCtx, cfg, crypt, httplogger, openidConfig, sessionStore)
h, err := router.NewHandler(jwksRefreshCtx, openidConfig, crypt, httplogger, sessionStore)
if err != nil {
return fmt.Errorf("initializing routing handler: %w", err)
}

View File

@@ -20,10 +20,9 @@ func Config() *config.Config {
}
type Configuration struct {
ClientConfig *TestClientConfiguration
ProviderConfig *openidconfig.Provider
IngressConfig string
LoginstatusConfig config.Loginstatus
ClientConfig *TestClientConfiguration
ProviderConfig *openidconfig.Provider
WonderwallConfig *config.Config
}
func (c Configuration) Client() openidconfig.Client {
@@ -34,19 +33,18 @@ func (c Configuration) Provider() *openidconfig.Provider {
return c.ProviderConfig
}
func (c Configuration) Ingress() string {
return c.IngressConfig
func (c Configuration) ProviderName() string {
return string(c.WonderwallConfig.OpenID.Provider)
}
func (c Configuration) Loginstatus() config.Loginstatus {
return c.LoginstatusConfig
func (c Configuration) Wonderwall() *config.Config {
return c.WonderwallConfig
}
func NewTestConfiguration(cfg *config.Config) Configuration {
return Configuration{
ClientConfig: clientConfiguration(cfg),
ProviderConfig: providerConfiguration(),
IngressConfig: cfg.Ingress,
LoginstatusConfig: cfg.Loginstatus,
ClientConfig: clientConfiguration(cfg),
ProviderConfig: providerConfiguration(),
WonderwallConfig: cfg,
}
}

View File

@@ -72,7 +72,7 @@ func NewIdentityProvider(cfg *config.Config) IdentityProvider {
sessionStore := session.NewMemory()
ctx, cancel := context.WithCancel(context.Background())
rpHandler, err := router.NewHandler(ctx, cfg, crypter, zerolog.Nop(), openidConfig, sessionStore)
rpHandler, err := router.NewHandler(ctx, openidConfig, crypter, zerolog.Nop(), sessionStore)
if err != nil {
panic(err)
}

View File

@@ -53,7 +53,7 @@ func NewLogin(c Client, r *http.Request) (Login, error) {
return nil, fmt.Errorf("generating login url: %w", err)
}
redirect := request.CanonicalRedirectURL(r, c.config().Ingress())
redirect := request.CanonicalRedirectURL(r, c.config().Wonderwall().Ingress)
cookie := params.cookie(redirect)
return login{
@@ -142,8 +142,8 @@ func (in *loginParameters) authCodeURL(r *http.Request) (string, error) {
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
}
if in.config().Loginstatus().NeedsResourceIndicator() {
opts = append(opts, oauth2.SetAuthURLParam("resource", in.config().Loginstatus().ResourceIndicator))
if in.config().Wonderwall().Loginstatus.NeedsResourceIndicator() {
opts = append(opts, oauth2.SetAuthURLParam("resource", in.config().Wonderwall().Loginstatus.ResourceIndicator))
}
opts, err := in.withSecurityLevel(r, opts)

View File

@@ -7,16 +7,15 @@ import (
type Config interface {
Client() Client
Provider() *Provider
ProviderName() string
Ingress() string
Loginstatus() wonderwallconfig.Loginstatus
Wonderwall() *wonderwallconfig.Config
}
type config struct {
cfg *wonderwallconfig.Config
clientConfig Client
providerConfig *Provider
ingress string
loginstatus wonderwallconfig.Loginstatus
}
func (c config) Client() Client {
@@ -27,12 +26,12 @@ func (c config) Provider() *Provider {
return c.providerConfig
}
func (c config) Ingress() string {
return c.ingress
func (c config) ProviderName() string {
return string(c.cfg.OpenID.Provider)
}
func (c config) Loginstatus() wonderwallconfig.Loginstatus {
return c.loginstatus
func (c config) Wonderwall() *wonderwallconfig.Config {
return c.cfg
}
func NewConfig(cfg *wonderwallconfig.Config) (Config, error) {
@@ -47,9 +46,8 @@ func NewConfig(cfg *wonderwallconfig.Config) (Config, error) {
}
return config{
cfg: cfg,
clientConfig: clientCfg,
providerConfig: providerCfg,
ingress: cfg.Ingress,
loginstatus: cfg.Loginstatus,
}, nil
}

View File

@@ -17,12 +17,11 @@ import (
)
type Handler struct {
Cfg openidconfig.Config
Client client.Client
Config *config.Config
CookieOptions cookie.Options
Crypter crypto.Crypter
Loginstatus loginstatus.Client
OpenIDConfig openidconfig.Config
Provider provider.Provider
Sessions session.Store
Httplogger zerolog.Logger
@@ -30,35 +29,33 @@ type Handler struct {
func NewHandler(
jwksRefreshCtx context.Context,
cfg *config.Config,
cfg openidconfig.Config,
crypter crypto.Crypter,
httplogger zerolog.Logger,
openidConfig openidconfig.Config,
sessionStore session.Store,
) (*Handler, error) {
loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient)
loginstatusClient := loginstatus.NewClient(cfg.Wonderwall().Loginstatus, http.DefaultClient)
cookiePath := config.ParseIngress(cfg.Ingress)
cookiePath := config.ParseIngress(cfg.Wonderwall().Ingress)
cookieOpts := cookie.DefaultOptions().WithPath(cookiePath)
openidProvider, err := provider.NewProvider(jwksRefreshCtx, openidConfig)
openidProvider, err := provider.NewProvider(jwksRefreshCtx, cfg)
if err != nil {
return nil, err
}
openidClient := client.NewClient(openidConfig)
openidClient := client.NewClient(cfg)
if err != nil {
return nil, err
}
return &Handler{
Client: openidClient,
Config: cfg,
CookieOptions: cookieOpts,
Crypter: crypter,
Httplogger: httplogger,
Loginstatus: loginstatusClient,
OpenIDConfig: openidConfig,
Cfg: cfg,
Provider: openidProvider,
Sessions: sessionStore,
}, nil

View File

@@ -71,7 +71,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
err = tokens.IDToken.Validate(h.OpenIDConfig, loginCookie.Nonce)
err = tokens.IDToken.Validate(h.Cfg, loginCookie.Nonce)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: validating id_token: %w", err))
return
@@ -83,7 +83,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
if h.Config.Loginstatus.Enabled {
if h.Cfg.Wonderwall().Loginstatus.Enabled {
tokenResponse, err := h.getLoginstatusToken(r.Context(), tokens)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: exchanging loginstatus token: %w", err))

View File

@@ -22,20 +22,20 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
isAuthenticated = true
// force new authentication if loginstatus is enabled and cookie isn't set
if h.Config.Loginstatus.Enabled && !h.Loginstatus.HasCookie(r) {
if h.Cfg.Wonderwall().Loginstatus.Enabled && !h.Loginstatus.HasCookie(r) {
isAuthenticated = false
log.Info("default: loginstatus was enabled, but no matching cookie was found; state is now unauthenticated")
}
}
if !isAuthenticated && h.Config.AutoLogin {
if !isAuthenticated && h.Cfg.Wonderwall().AutoLogin {
r.Header.Add("Referer", r.URL.String())
h.Login(w, r)
return
}
director := func(upstreamRequest *http.Request) {
modifyRequest(upstreamRequest, r, h.Config.UpstreamHost)
modifyRequest(upstreamRequest, r, h.Cfg.Wonderwall().UpstreamHost)
if isAuthenticated {
withAuthentication(upstreamRequest, sessionData)

View File

@@ -43,7 +43,7 @@ func (h *Handler) respondError(w http.ResponseWriter, r *http.Request, statusCod
logger := logentry.LogEntry(r.Context())
logger.WithLevel(level).Stack().Err(cause).Msgf("error in route: %+v", cause)
if len(h.Config.ErrorRedirectURI) > 0 {
if len(h.Cfg.Wonderwall().ErrorRedirectURI) > 0 {
err := h.customErrorRedirect(w, r, statusCode)
if err == nil {
return
@@ -63,7 +63,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s
errorPage := ErrorPage{
CorrelationID: middleware.GetReqID(r.Context()),
RetryURI: RetryURI(r, h.Config.Ingress, loginCookie),
RetryURI: RetryURI(r, h.Cfg.Wonderwall().Ingress, loginCookie),
}
err = errorTemplate.Execute(w, errorPage)
if err != nil {
@@ -72,7 +72,7 @@ func (h *Handler) defaultErrorResponse(w http.ResponseWriter, r *http.Request, s
}
func (h *Handler) customErrorRedirect(w http.ResponseWriter, r *http.Request, statusCode int) error {
override, err := url.Parse(h.Config.ErrorRedirectURI)
override, err := url.Parse(h.Cfg.Wonderwall().ErrorRedirectURI)
if err != nil {
return err
}

View File

@@ -16,7 +16,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
// Unconditionally destroy all local references to the session.
cookie.Clear(w, cookie.Session, h.CookieOptions)
if h.Config.Loginstatus.Enabled {
if h.Cfg.Wonderwall().Loginstatus.Enabled {
h.Loginstatus.ClearCookie(w, h.CookieOptions)
}

View File

@@ -37,11 +37,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
cookie.Clear(w, cookie.Session, h.CookieOptions)
if h.Config.Loginstatus.Enabled {
if h.Cfg.Wonderwall().Loginstatus.Enabled {
h.Loginstatus.ClearCookie(w, h.CookieOptions)
}
u, err := url.Parse(h.OpenIDConfig.Provider().EndSessionEndpoint)
u, err := url.Parse(h.Cfg.Provider().EndSessionEndpoint)
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
return
@@ -60,7 +60,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
}
v := u.Query()
v.Add("post_logout_redirect_uri", h.OpenIDConfig.Client().GetLogoutCallbackURI())
v.Add("post_logout_redirect_uri", h.Cfg.Client().GetLogoutCallbackURI())
v.Add("state", logoutCookie.State)
if len(idToken) > 0 {
@@ -86,7 +86,7 @@ func (h *Handler) logoutCookie() (*openid.LogoutCookie, error) {
return &openid.LogoutCookie{
State: state,
RedirectTo: h.OpenIDConfig.Client().GetPostLogoutRedirectURI(),
RedirectTo: h.Cfg.Client().GetPostLogoutRedirectURI(),
}, nil
}

View File

@@ -24,7 +24,7 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) {
logoutCookie, err := h.getLogoutCookie(r)
if err != nil {
logger.Warn().Msgf("logout/callback: getting cookie: %+v", err)
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect)
return
}
@@ -34,13 +34,13 @@ func (h *Handler) LogoutCallback(w http.ResponseWriter, r *http.Request) {
if expectedState != actualState {
logger.Warn().Msgf("logout/callback: state parameter mismatch: expected %s, got %s; falling back to ingress", expectedState, actualState)
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect)
return
}
if len(logoutCookie.RedirectTo) == 0 {
logger.Warn().Msgf("logout/callback: empty redirect; falling back to ingress")
http.Redirect(w, r, h.Config.Ingress, http.StatusTemporaryRedirect)
http.Redirect(w, r, h.Cfg.Wonderwall().Ingress, http.StatusTemporaryRedirect)
return
}

View File

@@ -13,9 +13,9 @@ func New(handler *Handler) chi.Router {
r := chi.NewRouter()
r.Use(middleware.CorrelationIDHandler)
r.Use(chi_middleware.Recoverer)
prometheusMiddleware := middleware.NewPrometheusMiddleware("wonderwall", string(handler.Config.OpenID.Provider))
prometheusMiddleware := middleware.NewPrometheusMiddleware("wonderwall", handler.Cfg.ProviderName())
prefix := config.ParseIngress(handler.Config.Ingress)
prefix := config.ParseIngress(handler.Cfg.Wonderwall().Ingress)
r.Route(prefix+paths.OAuth2, func(r chi.Router) {
r.Use(middleware.LogEntryHandler(handler.Httplogger))

View File

@@ -23,7 +23,7 @@ import (
// Thus, we cannot assume that the value of `sid` or `session_state` to uniquely identify the pair of (user, application session)
// if using a shared session store.
func (h *Handler) localSessionID(sessionID string) string {
return fmt.Sprintf("%s:%s:%s", h.Config.OpenID.Provider, h.OpenIDConfig.Client().GetClientID(), sessionID)
return fmt.Sprintf("%s:%s:%s", h.Cfg.ProviderName(), h.Cfg.Client().GetClientID(), sessionID)
}
func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {
@@ -67,7 +67,7 @@ func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Da
}
func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration {
defaultSessionLifetime := h.Config.SessionMaxLifetime
defaultSessionLifetime := h.Cfg.Wonderwall().SessionMaxLifetime
tokenDuration := tokenExpiry.Sub(time.Now())
@@ -79,7 +79,7 @@ func (h *Handler) getSessionLifetime(tokenExpiry time.Time) time.Duration {
}
func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *jwt.Tokens, rawTokens *oauth2.Token, params url.Values) error {
externalSessionID, err := session.NewSessionID(h.OpenIDConfig.Provider(), tokens.IDToken, params)
externalSessionID, err := session.NewSessionID(h.Cfg.Provider(), tokens.IDToken, params)
if err != nil {
return fmt.Errorf("generating session ID: %w", err)
}