refactor: introduce generic provider for openid configs

This commit is contained in:
Trong Huu Nguyen
2021-10-16 10:42:49 +02:00
parent 2f0243b69a
commit c702f8ff6c
8 changed files with 169 additions and 42 deletions

View File

@@ -66,20 +66,15 @@ func run() error {
}
}
crypt := cryptutil.New(key)
sessionStore := setupSessionStore(cfg)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
jwkSet, err := jwk.Fetch(ctx, cfg.IDPorten.WellKnown.JwksURI)
prv, err := provider.NewProvider(cfg)
if err != nil {
return fmt.Errorf("fetching jwks: %w", err)
return err
}
crypt := cryptutil.New(key)
sessionStore := setupSessionStore(cfg)
httplogger := logging.NewHttpLogger(cfg)
handler, err := router.NewHandler(*cfg, crypt, httplogger, jwkSet, sessionStore, cfg.UpstreamHost)
handler, err := router.NewHandler(*cfg, crypt, httplogger, prv, sessionStore)
if err != nil {
return fmt.Errorf("initializing routing handler: %w", err)
}

41
pkg/provider/assertion.go Normal file
View File

@@ -0,0 +1,41 @@
package provider
import (
"fmt"
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
)
func ClientAssertion(provider Provider, expiration time.Duration) (string, error) {
key := provider.GetClientConfiguration().GetClientJWK()
iat := time.Now()
exp := iat.Add(expiration)
errs := make([]error, 0)
tok := jwt.New()
errs = append(errs, tok.Set(jwt.IssuerKey, provider.GetClientConfiguration().GetClientID()))
errs = append(errs, tok.Set(jwt.SubjectKey, provider.GetClientConfiguration().GetClientID()))
errs = append(errs, tok.Set(jwt.AudienceKey, provider.GetOpenIDConfiguration().Issuer))
errs = append(errs, tok.Set("scope", provider.GetClientConfiguration().GetScopes().String()))
errs = append(errs, tok.Set(jwt.IssuedAtKey, iat))
errs = append(errs, tok.Set(jwt.ExpirationKey, exp))
errs = append(errs, tok.Set(jwt.JwtIDKey, uuid.New().String()))
for _, err := range errs {
if err != nil {
return "", fmt.Errorf("setting claim for client assertion: %w", err)
}
}
encoded, err := jwt.Sign(tok, jwa.SignatureAlgorithm(key.Algorithm()), key)
if err != nil {
return "", fmt.Errorf("signing client assertion: %w", err)
}
return string(encoded), nil
}

88
pkg/provider/provider.go Normal file
View File

@@ -0,0 +1,88 @@
package provider
import (
"context"
"fmt"
"github.com/lestrrat-go/jwx/jwk"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/openid"
)
type Provider interface {
GetClientConfiguration() openid.ClientConfiguration
GetOpenIDConfiguration() *openid.Configuration
GetPublicJwkSet() *jwk.Set
}
type provider struct {
clientConfiguration openid.ClientConfiguration
configuration *openid.Configuration
jwkSet *jwk.Set
}
func (p provider) GetClientConfiguration() openid.ClientConfiguration {
return p.clientConfiguration
}
func (p provider) GetOpenIDConfiguration() *openid.Configuration {
return p.configuration
}
func (p provider) GetPublicJwkSet() *jwk.Set {
return p.jwkSet
}
func NewProvider(cfg *config.Config) (Provider, error) {
clientJwkString := cfg.OpenID.ClientJWK
if len(clientJwkString) == 0 {
return nil, fmt.Errorf("missing required config %s", config.OpenIDClientJWK)
}
clientJwk, err := jwk.ParseKey([]byte(clientJwkString))
if err != nil {
return nil, fmt.Errorf("parsing client JWK: %w", err)
}
baseConfig := cfg.NewBaseConfig(clientJwk)
var clientConfig openid.ClientConfiguration
switch cfg.OpenID.Provider {
case "idporten":
clientConfig = baseConfig.IDPorten()
case "azure":
clientConfig = baseConfig.Azure()
case "":
return nil, fmt.Errorf("missing required config %s", config.OpenIDProvider)
default:
clientConfig = baseConfig
}
if len(clientConfig.GetClientID()) == 0 {
return nil, fmt.Errorf("missing required config %s", config.OpenIDClientID)
}
if len(clientConfig.GetWellKnownURL()) == 0 {
return nil, fmt.Errorf("missing required config %s", config.OpenIDWellKnownURL)
}
if len(clientConfig.GetRedirectURI()) == 0 {
return nil, fmt.Errorf("missing required config %s", config.OpenIDRedirectURI)
}
configuration, err := openid.FetchWellKnownConfig(clientConfig.GetWellKnownURL())
if err != nil {
return nil, fmt.Errorf("fetching well known config: %w", err)
}
jwkSet, err := configuration.FetchJwkSet(context.Background())
if err != nil {
return nil, fmt.Errorf("fetching jwk set: %w", err)
}
return &provider{
clientConfiguration: clientConfig,
configuration: configuration,
jwkSet: jwkSet,
}, nil
}

View File

@@ -1,14 +1,14 @@
package router
import (
"github.com/rs/zerolog"
"sync"
"github.com/lestrrat-go/jwx/jwk"
"github.com/rs/zerolog"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/cryptutil"
"github.com/nais/wonderwall/pkg/provider"
"github.com/nais/wonderwall/pkg/session"
)
@@ -16,10 +16,9 @@ type Handler struct {
Config config.Config
Crypter cryptutil.Crypter
OauthConfig oauth2.Config
Provider provider.Provider
SecureCookies bool
Sessions session.Store
UpstreamHost string
jwkSet jwk.Set
lock sync.Mutex
httplogger zerolog.Logger
}
@@ -28,30 +27,28 @@ func NewHandler(
cfg config.Config,
crypter cryptutil.Crypter,
httplogger zerolog.Logger,
jwkSet jwk.Set,
provider provider.Provider,
sessionStore session.Store,
upstreamHost string,
) (*Handler, error) {
oauthConfig := oauth2.Config{
ClientID: cfg.IDPorten.ClientID,
ClientID: provider.GetClientConfiguration().GetClientID(),
Endpoint: oauth2.Endpoint{
AuthURL: cfg.IDPorten.WellKnown.AuthorizationEndpoint,
TokenURL: cfg.IDPorten.WellKnown.TokenEndpoint,
AuthURL: provider.GetOpenIDConfiguration().AuthorizationEndpoint,
TokenURL: provider.GetOpenIDConfiguration().TokenEndpoint,
},
RedirectURL: cfg.IDPorten.RedirectURI,
Scopes: cfg.IDPorten.Scopes,
RedirectURL: provider.GetClientConfiguration().GetRedirectURI(),
Scopes: provider.GetClientConfiguration().GetScopes(),
}
return &Handler{
Config: cfg,
Crypter: crypter,
httplogger: httplogger,
jwkSet: jwkSet,
lock: sync.Mutex{},
OauthConfig: oauthConfig,
Provider: provider,
Sessions: sessionStore,
SecureCookies: true,
UpstreamHost: upstreamHost,
}, nil
}

View File

@@ -9,8 +9,8 @@ import (
"github.com/lestrrat-go/jwx/jwt"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/auth"
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/provider"
"github.com/nais/wonderwall/pkg/token"
)
@@ -40,7 +40,8 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
return
}
idToken, err := token.ParseIDToken(h.jwkSet, tokens)
jwkSet := h.Provider.GetPublicJwkSet()
idToken, err := token.ParseIDToken(*jwkSet, tokens)
if err != nil {
h.InternalError(w, r, fmt.Errorf("callback: parsing id_token: %w", err))
return
@@ -65,7 +66,7 @@ func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
}
func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *cookie.Login, code string) (*oauth2.Token, error) {
assertion, err := auth.ClientAssertion(h.Config.IDPorten, time.Second*30)
assertion, err := provider.ClientAssertion(h.Provider, time.Second*30)
if err != nil {
return nil, fmt.Errorf("creating client assertion: %w", err)
}
@@ -86,14 +87,14 @@ func (h *Handler) codeExchangeForToken(ctx context.Context, loginCookie *cookie.
func (h *Handler) validateIDToken(idToken *token.IDToken, loginCookie *cookie.Login) (string, error) {
validateOpts := []jwt.ValidateOption{
jwt.WithAudience(h.Config.IDPorten.ClientID),
jwt.WithAudience(h.Provider.GetClientConfiguration().GetClientID()),
jwt.WithClaimValue("nonce", loginCookie.Nonce),
jwt.WithIssuer(h.Config.IDPorten.WellKnown.Issuer),
jwt.WithIssuer(h.Provider.GetOpenIDConfiguration().Issuer),
jwt.WithAcceptableSkew(5 * time.Second),
jwt.WithRequiredClaim("sid"),
}
if h.Config.IDPorten.SecurityLevel.Enabled {
if h.Provider.GetClientConfiguration().GetACRValues().Enabled {
validateOpts = append(validateOpts, jwt.WithRequiredClaim("acr"))
}

View File

@@ -2,14 +2,15 @@ package router
import (
"fmt"
"github.com/nais/wonderwall/pkg/request"
"net/http"
"net/url"
"github.com/nais/wonderwall/pkg/request"
)
// Logout triggers self-initiated for the current user
func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
u, err := url.Parse(h.Config.IDPorten.WellKnown.EndSessionEndpoint)
u, err := url.Parse(h.Provider.GetOpenIDConfiguration().EndSessionEndpoint)
if err != nil {
h.InternalError(w, r, fmt.Errorf("logout: parsing end session endpoint: %w", err))
return
@@ -30,7 +31,11 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
h.deleteCookie(w, h.GetSessionCookieName())
v := u.Query()
v.Add("post_logout_redirect_uri", request.PostLogoutRedirectURI(r, h.Config.IDPorten.PostLogoutRedirectURI))
postLogoutURI := request.PostLogoutRedirectURI(r, h.Provider.GetClientConfiguration().GetPostLogoutRedirectURI())
if len(postLogoutURI) > 0 {
v.Add("post_logout_redirect_uri", postLogoutURI)
}
if len(idToken) != 0 {
v.Add("id_token_hint", idToken)

View File

@@ -23,9 +23,9 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.Parameters) (string,
v := u.Query()
v.Add("response_type", "code")
v.Add("client_id", h.Config.IDPorten.ClientID)
v.Add("redirect_uri", h.Config.IDPorten.RedirectURI)
v.Add("scope", token.ScopeOpenID)
v.Add("client_id", h.Provider.GetClientConfiguration().GetClientID())
v.Add("redirect_uri", h.Provider.GetClientConfiguration().GetRedirectURI())
v.Add("scope", h.Provider.GetClientConfiguration().GetScopes().String())
v.Add("state", params.State)
v.Add("nonce", params.Nonce)
v.Add("response_mode", "query")
@@ -48,12 +48,12 @@ func (h *Handler) LoginURL(r *http.Request, params *openid.Parameters) (string,
}
func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
if !h.Config.IDPorten.SecurityLevel.Enabled {
if !h.Provider.GetClientConfiguration().GetACRValues().Enabled {
return nil
}
fallback := h.Config.IDPorten.SecurityLevel.Value
supported := h.Config.IDPorten.WellKnown.ACRValuesSupported
fallback := h.Provider.GetClientConfiguration().GetACRValues().Value
supported := h.Provider.GetOpenIDConfiguration().ACRValuesSupported
securityLevel, err := request.LoginURLParameter(r, request.SecurityLevelURLParameter, fallback, supported)
if err != nil {
@@ -65,12 +65,12 @@ func (h *Handler) withSecurityLevel(r *http.Request, v url.Values) error {
}
func (h *Handler) withLocale(r *http.Request, v url.Values) error {
if !h.Config.IDPorten.Locale.Enabled {
if !h.Provider.GetClientConfiguration().GetUILocales().Enabled {
return nil
}
fallback := h.Config.IDPorten.Locale.Value
supported := h.Config.IDPorten.WellKnown.UILocalesSupported
fallback := h.Provider.GetClientConfiguration().GetUILocales().Value
supported := h.Provider.GetOpenIDConfiguration().UILocalesSupported
locale, err := request.LoginURLParameter(r, request.LocaleURLParameter, fallback, supported)
if err != nil {

View File

@@ -21,7 +21,7 @@ import (
// Thus, we cannot assume that the value of `sid` to uniquely identify the pair of (user, application session)
// if using a shared session store.
func (h *Handler) localSessionID(sid string) string {
return fmt.Sprintf("%s-%s", h.Config.IDPorten.ClientID, sid)
return fmt.Sprintf("%s:%s:%s", h.Config.OpenID.Provider, h.Provider.GetClientConfiguration().GetClientID(), sid)
}
func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (*session.Data, error) {