From c702f8ff6cc176c2c0f6194c10f7bb53efba684d Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Sat, 16 Oct 2021 10:42:49 +0200 Subject: [PATCH] refactor: introduce generic provider for openid configs --- cmd/wonderwall/main.go | 15 ++---- pkg/provider/assertion.go | 41 ++++++++++++++++ pkg/provider/provider.go | 88 ++++++++++++++++++++++++++++++++++ pkg/router/handler.go | 23 ++++----- pkg/router/handler_callback.go | 13 ++--- pkg/router/handler_logout.go | 11 +++-- pkg/router/login_url.go | 18 +++---- pkg/router/session.go | 2 +- 8 files changed, 169 insertions(+), 42 deletions(-) create mode 100644 pkg/provider/assertion.go create mode 100644 pkg/provider/provider.go diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 40646f4..246c017 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -66,20 +66,15 @@ func run() error { } } - crypt := cryptutil.New(key) - - sessionStore := setupSessionStore(cfg) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - jwkSet, err := jwk.Fetch(ctx, cfg.IDPorten.WellKnown.JwksURI) + prv, err := provider.NewProvider(cfg) if err != nil { - return fmt.Errorf("fetching jwks: %w", err) + return err } + crypt := cryptutil.New(key) + sessionStore := setupSessionStore(cfg) httplogger := logging.NewHttpLogger(cfg) - handler, err := router.NewHandler(*cfg, crypt, httplogger, jwkSet, sessionStore, cfg.UpstreamHost) + handler, err := router.NewHandler(*cfg, crypt, httplogger, prv, sessionStore) if err != nil { return fmt.Errorf("initializing routing handler: %w", err) } diff --git a/pkg/provider/assertion.go b/pkg/provider/assertion.go new file mode 100644 index 0000000..45d9348 --- /dev/null +++ b/pkg/provider/assertion.go @@ -0,0 +1,41 @@ +package provider + +import ( + "fmt" + "time" + + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/jwa" + "github.com/lestrrat-go/jwx/jwt" +) + +func ClientAssertion(provider Provider, expiration time.Duration) (string, error) { + key := provider.GetClientConfiguration().GetClientJWK() + + iat := time.Now() + exp := iat.Add(expiration) + + errs := make([]error, 0) + + tok := jwt.New() + errs = append(errs, tok.Set(jwt.IssuerKey, provider.GetClientConfiguration().GetClientID())) + errs = append(errs, tok.Set(jwt.SubjectKey, provider.GetClientConfiguration().GetClientID())) + errs = append(errs, tok.Set(jwt.AudienceKey, provider.GetOpenIDConfiguration().Issuer)) + errs = append(errs, tok.Set("scope", provider.GetClientConfiguration().GetScopes().String())) + errs = append(errs, tok.Set(jwt.IssuedAtKey, iat)) + errs = append(errs, tok.Set(jwt.ExpirationKey, exp)) + errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String())) + + for _, err := range errs { + if err != nil { + return "", fmt.Errorf("setting claim for client assertion: %w", err) + } + } + + encoded, err := jwt.Sign(tok, jwa.SignatureAlgorithm(key.Algorithm()), key) + if err != nil { + return "", fmt.Errorf("signing client assertion: %w", err) + } + + return string(encoded), nil +} diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go new file mode 100644 index 0000000..3aea6ee --- /dev/null +++ b/pkg/provider/provider.go @@ -0,0 +1,88 @@ +package provider + +import ( + "context" + "fmt" + + "github.com/lestrrat-go/jwx/jwk" + + "github.com/nais/wonderwall/pkg/config" + "github.com/nais/wonderwall/pkg/openid" +) + +type Provider interface { + GetClientConfiguration() openid.ClientConfiguration + GetOpenIDConfiguration() *openid.Configuration + GetPublicJwkSet() *jwk.Set +} + +type provider struct { + clientConfiguration openid.ClientConfiguration + configuration *openid.Configuration + jwkSet *jwk.Set +} + +func (p provider) GetClientConfiguration() openid.ClientConfiguration { + return p.clientConfiguration +} + +func (p provider) GetOpenIDConfiguration() *openid.Configuration { + return p.configuration +} + +func (p provider) GetPublicJwkSet() *jwk.Set { + return p.jwkSet +} + +func NewProvider(cfg *config.Config) (Provider, error) { + clientJwkString := cfg.OpenID.ClientJWK + if len(clientJwkString) == 0 { + return nil, fmt.Errorf("missing required config %s", config.OpenIDClientJWK) + } + + clientJwk, err := jwk.ParseKey([]byte(clientJwkString)) + if err != nil { + return nil, fmt.Errorf("parsing client JWK: %w", err) + } + + baseConfig := cfg.NewBaseConfig(clientJwk) + var clientConfig openid.ClientConfiguration + switch cfg.OpenID.Provider { + case "idporten": + clientConfig = baseConfig.IDPorten() + case "azure": + clientConfig = baseConfig.Azure() + case "": + return nil, fmt.Errorf("missing required config %s", config.OpenIDProvider) + default: + clientConfig = baseConfig + } + + if len(clientConfig.GetClientID()) == 0 { + return nil, fmt.Errorf("missing required config %s", config.OpenIDClientID) + } + + if len(clientConfig.GetWellKnownURL()) == 0 { + return nil, fmt.Errorf("missing required config %s", config.OpenIDWellKnownURL) + } + + if len(clientConfig.GetRedirectURI()) == 0 { + return nil, fmt.Errorf("missing required config %s", config.OpenIDRedirectURI) + } + + configuration, err := openid.FetchWellKnownConfig(clientConfig.GetWellKnownURL()) + if err != nil { + return nil, fmt.Errorf("fetching well known config: %w", err) + } + + jwkSet, err := configuration.FetchJwkSet(context.Background()) + if err != nil { + return nil, fmt.Errorf("fetching jwk set: %w", err) + } + + return &provider{ + clientConfiguration: clientConfig, + configuration: configuration, + jwkSet: jwkSet, + }, nil +} diff --git a/pkg/router/handler.go b/pkg/router/handler.go index e5b3562..64dd060 100644 --- a/pkg/router/handler.go +++ b/pkg/router/handler.go @@ -1,14 +1,14 @@ package router import ( - "github.com/rs/zerolog" "sync" - "github.com/lestrrat-go/jwx/jwk" + "github.com/rs/zerolog" "golang.org/x/oauth2" "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/cryptutil" + "github.com/nais/wonderwall/pkg/provider" "github.com/nais/wonderwall/pkg/session" ) @@ -16,10 +16,9 @@ type Handler struct { Config config.Config Crypter cryptutil.Crypter OauthConfig oauth2.Config + Provider provider.Provider SecureCookies bool Sessions session.Store - UpstreamHost string - jwkSet jwk.Set lock sync.Mutex httplogger zerolog.Logger } @@ -28,30 +27,28 @@ func NewHandler( cfg config.Config, crypter cryptutil.Crypter, httplogger zerolog.Logger, - jwkSet jwk.Set, + provider provider.Provider, sessionStore session.Store, - upstreamHost string, ) (*Handler, error) { oauthConfig := oauth2.Config{ - ClientID: cfg.IDPorten.ClientID, + ClientID: provider.GetClientConfiguration().GetClientID(), Endpoint: oauth2.Endpoint{ - AuthURL: cfg.IDPorten.WellKnown.AuthorizationEndpoint, - TokenURL: cfg.IDPorten.WellKnown.TokenEndpoint, + AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint, + TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint, }, - RedirectURL: cfg.IDPorten.RedirectURI, - Scopes: cfg.IDPorten.Scopes, + RedirectURL: provider.GetClientConfiguration().GetRedirectURI(), + Scopes: provider.GetClientConfiguration().GetScopes(), } return &Handler{ Config: cfg, Crypter: crypter, httplogger: httplogger, - jwkSet: jwkSet, lock: sync.Mutex{}, OauthConfig: oauthConfig, + Provider: provider, Sessions: sessionStore, SecureCookies: true, - UpstreamHost: upstreamHost, }, nil } diff --git a/pkg/router/handler_callback.go b/pkg/router/handler_callback.go index 4649c26..4a9a613 100644 --- a/pkg/router/handler_callback.go +++ b/pkg/router/handler_callback.go @@ -9,8 +9,8 @@ import ( "github.com/lestrrat-go/jwx/jwt" "golang.org/x/oauth2" - "github.com/nais/wonderwall/pkg/auth" "github.com/nais/wonderwall/pkg/cookie" + "github.com/nais/wonderwall/pkg/provider" "github.com/nais/wonderwall/pkg/token" ) @@ -40,7 +40,8 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { return } - idToken, err := token.ParseIDToken(h.jwkSet, tokens) + jwkSet := h.Provider.GetPublicJwkSet() + idToken, err := token.ParseIDToken(*jwkSet, tokens) if err != nil { h.InternalError(w, r, fmt.Errorf("callback: parsing id_token: %w", err)) return @@ -65,7 +66,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { } func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *cookie.Login, code string) (*oauth2.Token, error) { - assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second*30) + assertion, err := provider.ClientAssertion(h.Provider, time.Second*30) if err != nil { return nil, fmt.Errorf("creating client assertion: %w", err) } @@ -86,14 +87,14 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *cookie. func (h *Handler) validateIDToken(idToken *token.IDToken, loginCookie *cookie.Login) (string, error) { validateOpts := []jwt.ValidateOption{ - jwt.WithAudience(h.Config.IDPorten.ClientID), + jwt.WithAudience(h.Provider.GetClientConfiguration().GetClientID()), jwt.WithClaimValue("nonce", loginCookie.Nonce), - jwt.WithIssuer(h.Config.IDPorten.WellKnown.Issuer), + jwt.WithIssuer(h.Provider.GetOpenIDConfiguration().Issuer), jwt.WithAcceptableSkew(5 * time.Second), jwt.WithRequiredClaim("sid"), } - if h.Config.IDPorten.SecurityLevel.Enabled { + if h.Provider.GetClientConfiguration().GetACRValues().Enabled { validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr")) } diff --git a/pkg/router/handler_logout.go b/pkg/router/handler_logout.go index 1c993e7..3e6b38a 100644 --- a/pkg/router/handler_logout.go +++ b/pkg/router/handler_logout.go @@ -2,14 +2,15 @@ package router import ( "fmt" - "github.com/nais/wonderwall/pkg/request" "net/http" "net/url" + + "github.com/nais/wonderwall/pkg/request" ) // Logout triggers self-initiated for the current user func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { - u, err := url.Parse(h.Config.IDPorten.WellKnown.EndSessionEndpoint) + u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint) if err != nil { h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err)) return @@ -30,7 +31,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { h.deleteCookie(w, h.GetSessionCookieName()) v := u.Query() - v.Add("post_logout_redirect_uri", request.PostLogoutRedirectURI(r, h.Config.IDPorten.PostLogoutRedirectURI)) + + postLogoutURI := request.PostLogoutRedirectURI(r, h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI()) + if len(postLogoutURI) > 0 { + v.Add("post_logout_redirect_uri", postLogoutURI) + } if len(idToken) != 0 { v.Add("id_token_hint", idToken) diff --git a/pkg/router/login_url.go b/pkg/router/login_url.go index d75d931..ec2b5ae 100644 --- a/pkg/router/login_url.go +++ b/pkg/router/login_url.go @@ -23,9 +23,9 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.Parameters) (string, v := u.Query() v.Add("response_type", "code") - v.Add("client_id", h.Config.IDPorten.ClientID) - v.Add("redirect_uri", h.Config.IDPorten.RedirectURI) - v.Add("scope", token.ScopeOpenID) + v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID()) + v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetRedirectURI()) + v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String()) v.Add("state", params.State) v.Add("nonce", params.Nonce) v.Add("response_mode", "query") @@ -48,12 +48,12 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.Parameters) (string, } func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error { - if !h.Config.IDPorten.SecurityLevel.Enabled { + if !h.Provider.GetClientConfiguration().GetACRValues().Enabled { return nil } - fallback := h.Config.IDPorten.SecurityLevel.Value - supported := h.Config.IDPorten.WellKnown.ACRValuesSupported + fallback := h.Provider.GetClientConfiguration().GetACRValues().Value + supported := h.Provider.GetOpenIDConfiguration().ACRValuesSupported securityLevel, err := request.LoginURLParameter(r, request.SecurityLevelURLParameter, fallback, supported) if err != nil { @@ -65,12 +65,12 @@ func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error { } func (h *Handler) withLocale(r *http.Request, v url.Values) error { - if !h.Config.IDPorten.Locale.Enabled { + if !h.Provider.GetClientConfiguration().GetUILocales().Enabled { return nil } - fallback := h.Config.IDPorten.Locale.Value - supported := h.Config.IDPorten.WellKnown.UILocalesSupported + fallback := h.Provider.GetClientConfiguration().GetUILocales().Value + supported := h.Provider.GetOpenIDConfiguration().UILocalesSupported locale, err := request.LoginURLParameter(r, request.LocaleURLParameter, fallback, supported) if err != nil { diff --git a/pkg/router/session.go b/pkg/router/session.go index f20dc36..d740b73 100644 --- a/pkg/router/session.go +++ b/pkg/router/session.go @@ -21,7 +21,7 @@ import ( // Thus, we cannot assume that the value of `sid` to uniquely identify the pair of (user, application session) // if using a shared session store. func (h *Handler) localSessionID(sid string) string { - return fmt.Sprintf("%s-%s", h.Config.IDPorten.ClientID, sid) + return fmt.Sprintf("%s:%s:%s", h.Config.OpenID.Provider, h.Provider.GetClientConfiguration().GetClientID(), sid) } func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {