refactor(openid/provider): use name from config instead of indirection layer

This commit is contained in:
Trong Huu Nguyen
2023-01-31 14:19:57 +01:00
parent 2f6a3682d9
commit bd748b9cef
5 changed files with 15 additions and 23 deletions

View File

@@ -54,7 +54,7 @@ func run() error {
r := router.New(h)
go func() {
err := metrics.Handle(cfg.MetricsBindAddress, openidConfig)
err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider)
if err != nil {
log.Fatalf("fatal: metrics server error: %s", err)
}

View File

@@ -83,7 +83,7 @@ func (s *StandardHandler) GetPath(r *http.Request) string {
}
func (s *StandardHandler) GetProviderName() string {
return s.openidConfig.Provider().Name()
return string(s.config.OpenID.Provider)
}
func (s *StandardHandler) GetSessions() *sessionStore.Handler {

View File

@@ -7,7 +7,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
"github.com/nais/wonderwall/pkg/config"
)
const (
@@ -121,8 +121,8 @@ func InitLabels() {
}
}
func Handle(address string, openidConfig openidconfig.Config) error {
WithProvider(openidConfig.Provider().Name())
func Handle(address string, provider config.Provider) error {
WithProvider(string(provider))
Register(prometheus.DefaultRegisterer)
InitLabels()

View File

@@ -21,7 +21,6 @@ type Provider interface {
ACRValuesSupported() Supported
UILocalesSupported() Supported
Name() string
SessionStateRequired() bool
SidClaimRequired() bool
}
@@ -29,7 +28,6 @@ type Provider interface {
type provider struct {
endSessionEndpointURL *url.URL
metadata *ProviderMetadata
name string
}
func (p *provider) AuthorizationEndpoint() string {
@@ -60,10 +58,6 @@ func (p *provider) UILocalesSupported() Supported {
return p.metadata.UILocalesSupported
}
func (p *provider) Name() string {
return p.name
}
func (p *provider) SessionStateRequired() bool {
return len(p.metadata.CheckSessionIframe) > 0
}
@@ -104,7 +98,6 @@ func NewProviderConfig(cfg *wonderwallconfig.Config) (Provider, error) {
return &provider{
endSessionEndpointURL: endSessionEndpointURL,
metadata: providerCfg,
name: string(cfg.OpenID.Provider),
}, nil
}

View File

@@ -34,7 +34,7 @@ const (
)
type Handler struct {
cfg config.Session
cfg *config.Config
client *openidclient.Client
crypter crypto.Crypter
openidCfg openidconfig.Config
@@ -52,7 +52,7 @@ func NewHandler(cfg *config.Config, openidCfg openidconfig.Config, crypter crypt
client: openidClient,
openidCfg: openidCfg,
store: store,
cfg: cfg.Session,
cfg: cfg,
}, nil
}
@@ -67,8 +67,8 @@ func (h *Handler) Create(r *http.Request, tokens *openid.Tokens, sessionLifetime
tokenExpiresIn := time.Until(tokens.Expiry)
metadata := NewMetadata(tokenExpiresIn, sessionLifetime)
if h.cfg.Inactivity {
metadata.WithTimeout(h.cfg.InactivityTimeout)
if h.cfg.Session.Inactivity {
metadata.WithTimeout(h.cfg.Session.InactivityTimeout)
}
encrypted, err := NewData(externalSessionID, tokens, metadata).Encrypt(h.crypter)
@@ -224,10 +224,9 @@ func (h *Handler) IDOrGenerate(r *http.Request, tokens *openid.Tokens) (string,
// the value of `sid` or `session_state` to uniquely identify the pair of (user, application session) if using a shared
// session store across multiple Relying Parties.
func (h *Handler) Key(sessionID string) string {
provider := h.openidCfg.Provider()
client := h.openidCfg.Client()
return fmt.Sprintf("%s:%s:%s", provider.Name(), client.ClientID(), sessionID)
return fmt.Sprintf("%s:%s:%s", h.cfg.OpenID.Provider, client.ClientID(), sessionID)
}
// Refresh refreshes the user's session and returns the updated session data.
@@ -314,8 +313,8 @@ func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error
data.RefreshToken = resp.RefreshToken
data.Metadata.Refresh(resp.ExpiresIn)
if h.cfg.Inactivity {
data.Metadata.ExtendTimeout(h.cfg.InactivityTimeout)
if h.cfg.Session.Inactivity {
data.Metadata.ExtendTimeout(h.cfg.Session.InactivityTimeout)
}
err = h.Update(ctx, key, data)
@@ -349,15 +348,15 @@ func (h *Handler) Update(ctx context.Context, key string, data *Data) error {
}
func (h *Handler) canRefresh(data *Data) bool {
return h.cfg.Refresh && data.HasRefreshToken() && !data.Metadata.IsRefreshOnCooldown()
return h.cfg.Session.Refresh && data.HasRefreshToken() && !data.Metadata.IsRefreshOnCooldown()
}
func (h *Handler) shouldRefresh(data *Data) bool {
return h.cfg.Refresh && data.HasRefreshToken() && data.Metadata.ShouldRefresh()
return h.cfg.Session.Refresh && data.HasRefreshToken() && data.Metadata.ShouldRefresh()
}
func (h *Handler) isTimedOut(data *Data) bool {
return h.cfg.Inactivity && data.Metadata.IsTimedOut()
return h.cfg.Session.Inactivity && data.Metadata.IsTimedOut()
}
func NewSessionID(cfg openidconfig.Provider, idToken *openid.IDToken, params url.Values) (string, error) {