mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-09 09:56:48 +00:00
refactor: introduce generic provider for openid configs
This commit is contained in:
@@ -66,20 +66,15 @@ func run() error {
|
||||
}
|
||||
}
|
||||
|
||||
crypt := cryptutil.New(key)
|
||||
|
||||
sessionStore := setupSessionStore(cfg)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
jwkSet, err := jwk.Fetch(ctx, cfg.IDPorten.WellKnown.JwksURI)
|
||||
prv, err := provider.NewProvider(cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching jwks: %w", err)
|
||||
return err
|
||||
}
|
||||
|
||||
crypt := cryptutil.New(key)
|
||||
sessionStore := setupSessionStore(cfg)
|
||||
httplogger := logging.NewHttpLogger(cfg)
|
||||
handler, err := router.NewHandler(*cfg, crypt, httplogger, jwkSet, sessionStore, cfg.UpstreamHost)
|
||||
handler, err := router.NewHandler(*cfg, crypt, httplogger, prv, sessionStore)
|
||||
if err != nil {
|
||||
return fmt.Errorf("initializing routing handler: %w", err)
|
||||
}
|
||||
|
||||
41
pkg/provider/assertion.go
Normal file
41
pkg/provider/assertion.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
)
|
||||
|
||||
func ClientAssertion(provider Provider, expiration time.Duration) (string, error) {
|
||||
key := provider.GetClientConfiguration().GetClientJWK()
|
||||
|
||||
iat := time.Now()
|
||||
exp := iat.Add(expiration)
|
||||
|
||||
errs := make([]error, 0)
|
||||
|
||||
tok := jwt.New()
|
||||
errs = append(errs, tok.Set(jwt.IssuerKey, provider.GetClientConfiguration().GetClientID()))
|
||||
errs = append(errs, tok.Set(jwt.SubjectKey, provider.GetClientConfiguration().GetClientID()))
|
||||
errs = append(errs, tok.Set(jwt.AudienceKey, provider.GetOpenIDConfiguration().Issuer))
|
||||
errs = append(errs, tok.Set("scope", provider.GetClientConfiguration().GetScopes().String()))
|
||||
errs = append(errs, tok.Set(jwt.IssuedAtKey, iat))
|
||||
errs = append(errs, tok.Set(jwt.ExpirationKey, exp))
|
||||
errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String()))
|
||||
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("setting claim for client assertion: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
encoded, err := jwt.Sign(tok, jwa.SignatureAlgorithm(key.Algorithm()), key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("signing client assertion: %w", err)
|
||||
}
|
||||
|
||||
return string(encoded), nil
|
||||
}
|
||||
88
pkg/provider/provider.go
Normal file
88
pkg/provider/provider.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
GetClientConfiguration() openid.ClientConfiguration
|
||||
GetOpenIDConfiguration() *openid.Configuration
|
||||
GetPublicJwkSet() *jwk.Set
|
||||
}
|
||||
|
||||
type provider struct {
|
||||
clientConfiguration openid.ClientConfiguration
|
||||
configuration *openid.Configuration
|
||||
jwkSet *jwk.Set
|
||||
}
|
||||
|
||||
func (p provider) GetClientConfiguration() openid.ClientConfiguration {
|
||||
return p.clientConfiguration
|
||||
}
|
||||
|
||||
func (p provider) GetOpenIDConfiguration() *openid.Configuration {
|
||||
return p.configuration
|
||||
}
|
||||
|
||||
func (p provider) GetPublicJwkSet() *jwk.Set {
|
||||
return p.jwkSet
|
||||
}
|
||||
|
||||
func NewProvider(cfg *config.Config) (Provider, error) {
|
||||
clientJwkString := cfg.OpenID.ClientJWK
|
||||
if len(clientJwkString) == 0 {
|
||||
return nil, fmt.Errorf("missing required config %s", config.OpenIDClientJWK)
|
||||
}
|
||||
|
||||
clientJwk, err := jwk.ParseKey([]byte(clientJwkString))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing client JWK: %w", err)
|
||||
}
|
||||
|
||||
baseConfig := cfg.NewBaseConfig(clientJwk)
|
||||
var clientConfig openid.ClientConfiguration
|
||||
switch cfg.OpenID.Provider {
|
||||
case "idporten":
|
||||
clientConfig = baseConfig.IDPorten()
|
||||
case "azure":
|
||||
clientConfig = baseConfig.Azure()
|
||||
case "":
|
||||
return nil, fmt.Errorf("missing required config %s", config.OpenIDProvider)
|
||||
default:
|
||||
clientConfig = baseConfig
|
||||
}
|
||||
|
||||
if len(clientConfig.GetClientID()) == 0 {
|
||||
return nil, fmt.Errorf("missing required config %s", config.OpenIDClientID)
|
||||
}
|
||||
|
||||
if len(clientConfig.GetWellKnownURL()) == 0 {
|
||||
return nil, fmt.Errorf("missing required config %s", config.OpenIDWellKnownURL)
|
||||
}
|
||||
|
||||
if len(clientConfig.GetRedirectURI()) == 0 {
|
||||
return nil, fmt.Errorf("missing required config %s", config.OpenIDRedirectURI)
|
||||
}
|
||||
|
||||
configuration, err := openid.FetchWellKnownConfig(clientConfig.GetWellKnownURL())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching well known config: %w", err)
|
||||
}
|
||||
|
||||
jwkSet, err := configuration.FetchJwkSet(context.Background())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching jwk set: %w", err)
|
||||
}
|
||||
|
||||
return &provider{
|
||||
clientConfiguration: clientConfig,
|
||||
configuration: configuration,
|
||||
jwkSet: jwkSet,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,14 +1,14 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"github.com/rs/zerolog"
|
||||
"sync"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/config"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"github.com/nais/wonderwall/pkg/provider"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
@@ -16,10 +16,9 @@ type Handler struct {
|
||||
Config config.Config
|
||||
Crypter cryptutil.Crypter
|
||||
OauthConfig oauth2.Config
|
||||
Provider provider.Provider
|
||||
SecureCookies bool
|
||||
Sessions session.Store
|
||||
UpstreamHost string
|
||||
jwkSet jwk.Set
|
||||
lock sync.Mutex
|
||||
httplogger zerolog.Logger
|
||||
}
|
||||
@@ -28,30 +27,28 @@ func NewHandler(
|
||||
cfg config.Config,
|
||||
crypter cryptutil.Crypter,
|
||||
httplogger zerolog.Logger,
|
||||
jwkSet jwk.Set,
|
||||
provider provider.Provider,
|
||||
sessionStore session.Store,
|
||||
upstreamHost string,
|
||||
) (*Handler, error) {
|
||||
oauthConfig := oauth2.Config{
|
||||
ClientID: cfg.IDPorten.ClientID,
|
||||
ClientID: provider.GetClientConfiguration().GetClientID(),
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: cfg.IDPorten.WellKnown.AuthorizationEndpoint,
|
||||
TokenURL: cfg.IDPorten.WellKnown.TokenEndpoint,
|
||||
AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint,
|
||||
TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint,
|
||||
},
|
||||
RedirectURL: cfg.IDPorten.RedirectURI,
|
||||
Scopes: cfg.IDPorten.Scopes,
|
||||
RedirectURL: provider.GetClientConfiguration().GetRedirectURI(),
|
||||
Scopes: provider.GetClientConfiguration().GetScopes(),
|
||||
}
|
||||
|
||||
return &Handler{
|
||||
Config: cfg,
|
||||
Crypter: crypter,
|
||||
httplogger: httplogger,
|
||||
jwkSet: jwkSet,
|
||||
lock: sync.Mutex{},
|
||||
OauthConfig: oauthConfig,
|
||||
Provider: provider,
|
||||
Sessions: sessionStore,
|
||||
SecureCookies: true,
|
||||
UpstreamHost: upstreamHost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/auth"
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
"github.com/nais/wonderwall/pkg/provider"
|
||||
"github.com/nais/wonderwall/pkg/token"
|
||||
)
|
||||
|
||||
@@ -40,7 +40,8 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
idToken, err := token.ParseIDToken(h.jwkSet, tokens)
|
||||
jwkSet := h.Provider.GetPublicJwkSet()
|
||||
idToken, err := token.ParseIDToken(*jwkSet, tokens)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("callback: parsing id_token: %w", err))
|
||||
return
|
||||
@@ -65,7 +66,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *cookie.Login, code string) (*oauth2.Token, error) {
|
||||
assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second*30)
|
||||
assertion, err := provider.ClientAssertion(h.Provider, time.Second*30)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating client assertion: %w", err)
|
||||
}
|
||||
@@ -86,14 +87,14 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *cookie.
|
||||
|
||||
func (h *Handler) validateIDToken(idToken *token.IDToken, loginCookie *cookie.Login) (string, error) {
|
||||
validateOpts := []jwt.ValidateOption{
|
||||
jwt.WithAudience(h.Config.IDPorten.ClientID),
|
||||
jwt.WithAudience(h.Provider.GetClientConfiguration().GetClientID()),
|
||||
jwt.WithClaimValue("nonce", loginCookie.Nonce),
|
||||
jwt.WithIssuer(h.Config.IDPorten.WellKnown.Issuer),
|
||||
jwt.WithIssuer(h.Provider.GetOpenIDConfiguration().Issuer),
|
||||
jwt.WithAcceptableSkew(5 * time.Second),
|
||||
jwt.WithRequiredClaim("sid"),
|
||||
}
|
||||
|
||||
if h.Config.IDPorten.SecurityLevel.Enabled {
|
||||
if h.Provider.GetClientConfiguration().GetACRValues().Enabled {
|
||||
validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr"))
|
||||
}
|
||||
|
||||
|
||||
@@ -2,14 +2,15 @@ package router
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/nais/wonderwall/pkg/request"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/request"
|
||||
)
|
||||
|
||||
// Logout triggers self-initiated for the current user
|
||||
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
u, err := url.Parse(h.Config.IDPorten.WellKnown.EndSessionEndpoint)
|
||||
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint)
|
||||
if err != nil {
|
||||
h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
|
||||
return
|
||||
@@ -30,7 +31,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
|
||||
h.deleteCookie(w, h.GetSessionCookieName())
|
||||
|
||||
v := u.Query()
|
||||
v.Add("post_logout_redirect_uri", request.PostLogoutRedirectURI(r, h.Config.IDPorten.PostLogoutRedirectURI))
|
||||
|
||||
postLogoutURI := request.PostLogoutRedirectURI(r, h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI())
|
||||
if len(postLogoutURI) > 0 {
|
||||
v.Add("post_logout_redirect_uri", postLogoutURI)
|
||||
}
|
||||
|
||||
if len(idToken) != 0 {
|
||||
v.Add("id_token_hint", idToken)
|
||||
|
||||
@@ -23,9 +23,9 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.Parameters) (string,
|
||||
|
||||
v := u.Query()
|
||||
v.Add("response_type", "code")
|
||||
v.Add("client_id", h.Config.IDPorten.ClientID)
|
||||
v.Add("redirect_uri", h.Config.IDPorten.RedirectURI)
|
||||
v.Add("scope", token.ScopeOpenID)
|
||||
v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID())
|
||||
v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetRedirectURI())
|
||||
v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String())
|
||||
v.Add("state", params.State)
|
||||
v.Add("nonce", params.Nonce)
|
||||
v.Add("response_mode", "query")
|
||||
@@ -48,12 +48,12 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.Parameters) (string,
|
||||
}
|
||||
|
||||
func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
|
||||
if !h.Config.IDPorten.SecurityLevel.Enabled {
|
||||
if !h.Provider.GetClientConfiguration().GetACRValues().Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
fallback := h.Config.IDPorten.SecurityLevel.Value
|
||||
supported := h.Config.IDPorten.WellKnown.ACRValuesSupported
|
||||
fallback := h.Provider.GetClientConfiguration().GetACRValues().Value
|
||||
supported := h.Provider.GetOpenIDConfiguration().ACRValuesSupported
|
||||
|
||||
securityLevel, err := request.LoginURLParameter(r, request.SecurityLevelURLParameter, fallback, supported)
|
||||
if err != nil {
|
||||
@@ -65,12 +65,12 @@ func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
|
||||
}
|
||||
|
||||
func (h *Handler) withLocale(r *http.Request, v url.Values) error {
|
||||
if !h.Config.IDPorten.Locale.Enabled {
|
||||
if !h.Provider.GetClientConfiguration().GetUILocales().Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
fallback := h.Config.IDPorten.Locale.Value
|
||||
supported := h.Config.IDPorten.WellKnown.UILocalesSupported
|
||||
fallback := h.Provider.GetClientConfiguration().GetUILocales().Value
|
||||
supported := h.Provider.GetOpenIDConfiguration().UILocalesSupported
|
||||
|
||||
locale, err := request.LoginURLParameter(r, request.LocaleURLParameter, fallback, supported)
|
||||
if err != nil {
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
// Thus, we cannot assume that the value of `sid` to uniquely identify the pair of (user, application session)
|
||||
// if using a shared session store.
|
||||
func (h *Handler) localSessionID(sid string) string {
|
||||
return fmt.Sprintf("%s-%s", h.Config.IDPorten.ClientID, sid)
|
||||
return fmt.Sprintf("%s:%s:%s", h.Config.OpenID.Provider, h.Provider.GetClientConfiguration().GetClientID(), sid)
|
||||
}
|
||||
|
||||
func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {
|
||||
|
||||
Reference in New Issue
Block a user