From bd748b9cef040c19399fdedb8a8739d5e626c310 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Tue, 31 Jan 2023 14:19:57 +0100 Subject: [PATCH] refactor(openid/provider): use name from config instead of indirection layer --- cmd/wonderwall/main.go | 2 +- pkg/handler/handler_standard.go | 2 +- pkg/metrics/metrics.go | 6 +++--- pkg/openid/config/provider.go | 7 ------- pkg/session/handler.go | 21 ++++++++++----------- 5 files changed, 15 insertions(+), 23 deletions(-) diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 2f01c76..c3a0316 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -54,7 +54,7 @@ func run() error { r := router.New(h) go func() { - err := metrics.Handle(cfg.MetricsBindAddress, openidConfig) + err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider) if err != nil { log.Fatalf("fatal: metrics server error: %s", err) } diff --git a/pkg/handler/handler_standard.go b/pkg/handler/handler_standard.go index 662775a..8bd103f 100644 --- a/pkg/handler/handler_standard.go +++ b/pkg/handler/handler_standard.go @@ -83,7 +83,7 @@ func (s *StandardHandler) GetPath(r *http.Request) string { } func (s *StandardHandler) GetProviderName() string { - return s.openidConfig.Provider().Name() + return string(s.config.OpenID.Provider) } func (s *StandardHandler) GetSessions() *sessionStore.Handler { diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index 3405a38..a120369 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -7,7 +7,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" - openidconfig "github.com/nais/wonderwall/pkg/openid/config" + "github.com/nais/wonderwall/pkg/config" ) const ( @@ -121,8 +121,8 @@ func InitLabels() { } } -func Handle(address string, openidConfig openidconfig.Config) error { - WithProvider(openidConfig.Provider().Name()) +func Handle(address string, provider config.Provider) error { + WithProvider(string(provider)) Register(prometheus.DefaultRegisterer) InitLabels() diff --git a/pkg/openid/config/provider.go b/pkg/openid/config/provider.go index ad6cdfc..846091d 100644 --- a/pkg/openid/config/provider.go +++ b/pkg/openid/config/provider.go @@ -21,7 +21,6 @@ type Provider interface { ACRValuesSupported() Supported UILocalesSupported() Supported - Name() string SessionStateRequired() bool SidClaimRequired() bool } @@ -29,7 +28,6 @@ type Provider interface { type provider struct { endSessionEndpointURL *url.URL metadata *ProviderMetadata - name string } func (p *provider) AuthorizationEndpoint() string { @@ -60,10 +58,6 @@ func (p *provider) UILocalesSupported() Supported { return p.metadata.UILocalesSupported } -func (p *provider) Name() string { - return p.name -} - func (p *provider) SessionStateRequired() bool { return len(p.metadata.CheckSessionIframe) > 0 } @@ -104,7 +98,6 @@ func NewProviderConfig(cfg *wonderwallconfig.Config) (Provider, error) { return &provider{ endSessionEndpointURL: endSessionEndpointURL, metadata: providerCfg, - name: string(cfg.OpenID.Provider), }, nil } diff --git a/pkg/session/handler.go b/pkg/session/handler.go index 5353189..8dd18c4 100644 --- a/pkg/session/handler.go +++ b/pkg/session/handler.go @@ -34,7 +34,7 @@ const ( ) type Handler struct { - cfg config.Session + cfg *config.Config client *openidclient.Client crypter crypto.Crypter openidCfg openidconfig.Config @@ -52,7 +52,7 @@ func NewHandler(cfg *config.Config, openidCfg openidconfig.Config, crypter crypt client: openidClient, openidCfg: openidCfg, store: store, - cfg: cfg.Session, + cfg: cfg, }, nil } @@ -67,8 +67,8 @@ func (h *Handler) Create(r *http.Request, tokens *openid.Tokens, sessionLifetime tokenExpiresIn := time.Until(tokens.Expiry) metadata := NewMetadata(tokenExpiresIn, sessionLifetime) - if h.cfg.Inactivity { - metadata.WithTimeout(h.cfg.InactivityTimeout) + if h.cfg.Session.Inactivity { + metadata.WithTimeout(h.cfg.Session.InactivityTimeout) } encrypted, err := NewData(externalSessionID, tokens, metadata).Encrypt(h.crypter) @@ -224,10 +224,9 @@ func (h *Handler) IDOrGenerate(r *http.Request, tokens *openid.Tokens) (string, // the value of `sid` or `session_state` to uniquely identify the pair of (user, application session) if using a shared // session store across multiple Relying Parties. func (h *Handler) Key(sessionID string) string { - provider := h.openidCfg.Provider() client := h.openidCfg.Client() - return fmt.Sprintf("%s:%s:%s", provider.Name(), client.ClientID(), sessionID) + return fmt.Sprintf("%s:%s:%s", h.cfg.OpenID.Provider, client.ClientID(), sessionID) } // Refresh refreshes the user's session and returns the updated session data. @@ -314,8 +313,8 @@ func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error data.RefreshToken = resp.RefreshToken data.Metadata.Refresh(resp.ExpiresIn) - if h.cfg.Inactivity { - data.Metadata.ExtendTimeout(h.cfg.InactivityTimeout) + if h.cfg.Session.Inactivity { + data.Metadata.ExtendTimeout(h.cfg.Session.InactivityTimeout) } err = h.Update(ctx, key, data) @@ -349,15 +348,15 @@ func (h *Handler) Update(ctx context.Context, key string, data *Data) error { } func (h *Handler) canRefresh(data *Data) bool { - return h.cfg.Refresh && data.HasRefreshToken() && !data.Metadata.IsRefreshOnCooldown() + return h.cfg.Session.Refresh && data.HasRefreshToken() && !data.Metadata.IsRefreshOnCooldown() } func (h *Handler) shouldRefresh(data *Data) bool { - return h.cfg.Refresh && data.HasRefreshToken() && data.Metadata.ShouldRefresh() + return h.cfg.Session.Refresh && data.HasRefreshToken() && data.Metadata.ShouldRefresh() } func (h *Handler) isTimedOut(data *Data) bool { - return h.cfg.Inactivity && data.Metadata.IsTimedOut() + return h.cfg.Session.Inactivity && data.Metadata.IsTimedOut() } func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url.Values) (string, error) {