mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 23:32:57 +00:00
165 lines
4.6 KiB
Go
165 lines
4.6 KiB
Go
package openid
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/lestrrat-go/jwx/jwk"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/nais/wonderwall/pkg/config"
|
|
"github.com/nais/wonderwall/pkg/openid/clients"
|
|
)
|
|
|
|
// JwkMinimumRefreshInterval is the shortest allowed gap between two forced
// JWKS refreshes. RefreshPublicJwkSet answers from the cached set when called
// again within this window, to avoid hammering the identity provider.
const JwkMinimumRefreshInterval = time.Second * 5
|
|
|
|
type Provider interface {
|
|
GetClientConfiguration() clients.Configuration
|
|
GetOpenIDConfiguration() *Configuration
|
|
GetPublicJwkSet(ctx context.Context) (*jwk.Set, error)
|
|
RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error)
|
|
}
|
|
|
|
type provider struct {
|
|
clientConfiguration clients.Configuration
|
|
configuration *Configuration
|
|
jwks *jwk.AutoRefresh
|
|
jwksLock *jwksLock
|
|
}
|
|
|
|
// jwksLock serializes forced JWKS refreshes. The embedded mutex guards
// lastRefresh, the moment the most recent refresh attempt started.
type jwksLock struct {
	sync.Mutex

	lastRefresh time.Time
}
|
|
|
|
func (p provider) GetClientConfiguration() clients.Configuration {
|
|
return p.clientConfiguration
|
|
}
|
|
|
|
func (p provider) GetOpenIDConfiguration() *Configuration {
|
|
return p.configuration
|
|
}
|
|
|
|
func (p provider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
|
|
url := p.configuration.JwksURI
|
|
set, err := p.jwks.Fetch(ctx, url)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("provider: fetching jwks: %w", err)
|
|
}
|
|
|
|
return &set, nil
|
|
}
|
|
|
|
func (p provider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
|
|
p.jwksLock.Lock()
|
|
defer p.jwksLock.Unlock()
|
|
|
|
// redirect to cache if recently refreshed to avoid overwhelming provider
|
|
diff := time.Now().Sub(p.jwksLock.lastRefresh)
|
|
if diff < JwkMinimumRefreshInterval {
|
|
return p.GetPublicJwkSet(ctx)
|
|
}
|
|
|
|
p.jwksLock.lastRefresh = time.Now()
|
|
|
|
url := p.configuration.JwksURI
|
|
set, err := p.jwks.Refresh(ctx, url)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("provider: refreshing jwks: %w", err)
|
|
}
|
|
|
|
return &set, nil
|
|
}
|
|
|
|
func NewProvider(ctx context.Context, cfg *config.Config) (Provider, error) {
|
|
clientJwkString := cfg.OpenID.ClientJWK
|
|
if len(clientJwkString) == 0 {
|
|
return nil, fmt.Errorf("missing required config %s", config.OpenIDClientJWK)
|
|
}
|
|
|
|
clientJwk, err := jwk.ParseKey([]byte(clientJwkString))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing client JWK: %w", err)
|
|
}
|
|
|
|
ingress := cfg.Ingress
|
|
if len(ingress) == 0 {
|
|
return nil, fmt.Errorf("missing required config %s", config.Ingress)
|
|
}
|
|
|
|
redirectURI, err := RedirectURI(ingress)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating redirect URI from ingress: %w", err)
|
|
}
|
|
|
|
openIDConfig := clients.NewOpenIDConfig(*cfg, clientJwk, redirectURI)
|
|
var clientConfig clients.Configuration
|
|
switch cfg.OpenID.Provider {
|
|
case config.ProviderIDPorten:
|
|
clientConfig = openIDConfig.IDPorten()
|
|
case config.ProviderAzure:
|
|
clientConfig = openIDConfig.Azure()
|
|
case "":
|
|
return nil, fmt.Errorf("missing required config %s", config.OpenIDProvider)
|
|
default:
|
|
clientConfig = openIDConfig
|
|
}
|
|
|
|
if len(clientConfig.GetClientID()) == 0 {
|
|
return nil, fmt.Errorf("missing required config %s", config.OpenIDClientID)
|
|
}
|
|
|
|
if len(clientConfig.GetWellKnownURL()) == 0 {
|
|
return nil, fmt.Errorf("missing required config %s", config.OpenIDWellKnownURL)
|
|
}
|
|
|
|
configuration, err := FetchWellKnownConfig(clientConfig.GetWellKnownURL())
|
|
if err != nil {
|
|
return nil, fmt.Errorf("fetching well known config: %w", err)
|
|
}
|
|
|
|
printConfigs(clientConfig, *configuration)
|
|
|
|
acrValues := clientConfig.GetACRValues()
|
|
if len(acrValues) > 0 && !configuration.ACRValuesSupported.Contains(acrValues) {
|
|
return nil, fmt.Errorf("identity provider does not support '%s=%s'", config.OpenIDACRValues, acrValues)
|
|
}
|
|
|
|
uiLocales := clientConfig.GetUILocales()
|
|
if len(uiLocales) > 0 && !configuration.UILocalesSupported.Contains(uiLocales) {
|
|
return nil, fmt.Errorf("identity provider does not support '%s=%s'", config.OpenIDUILocales, acrValues)
|
|
}
|
|
|
|
uri := configuration.JwksURI
|
|
jwksAutoRefresh := jwk.NewAutoRefresh(ctx)
|
|
jwksAutoRefresh.Configure(uri)
|
|
_, err = jwksAutoRefresh.Fetch(ctx, uri)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("initial fetch of jwks from provider: %w", err)
|
|
}
|
|
|
|
return &provider{
|
|
clientConfiguration: clientConfig,
|
|
configuration: configuration,
|
|
jwks: jwksAutoRefresh,
|
|
jwksLock: &jwksLock{},
|
|
}, nil
|
|
}
|
|
|
|
func printConfigs(clientCfg clients.Configuration, openIdCfg Configuration) {
|
|
log.Info("🤔 openid client configuration 🤔")
|
|
log.Infof("acr values: '%s'", clientCfg.GetACRValues())
|
|
log.Infof("client id: '%s'", clientCfg.GetClientID())
|
|
log.Infof("post-logout redirect uri: '%s'", clientCfg.GetPostLogoutRedirectURI())
|
|
log.Infof("redirect uri: '%s'", clientCfg.GetRedirectURI())
|
|
log.Infof("scopes: '%s'", clientCfg.GetScopes())
|
|
log.Infof("ui locales: '%s'", clientCfg.GetUILocales())
|
|
|
|
log.Info("😗 openid provider configuration 😗")
|
|
log.Infof("%#v", openIdCfg)
|
|
}
|