refactor(openid): clean up client and provider

This commit is contained in:
Trong Huu Nguyen
2022-09-02 18:08:36 +02:00
parent 92ee6313c5
commit 08eefbf1d5
16 changed files with 62 additions and 78 deletions

View File

@@ -40,12 +40,12 @@ func run() error {
cookieOpts := cookie.DefaultOptions()
openidProvider, err := provider.NewProvider(ctx, openidConfig)
jwksProvider, err := provider.NewJwksProvider(ctx, openidConfig)
if err != nil {
return err
}
h, err := handler.NewHandler(cfg, cookieOpts, openidConfig, openidProvider, crypt)
h, err := handler.NewHandler(cfg, cookieOpts, jwksProvider, openidConfig, crypt)
if err != nil {
return fmt.Errorf("initializing routing handler: %w", err)
}

View File

@@ -12,7 +12,6 @@ import (
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
"github.com/nais/wonderwall/pkg/loginstatus"
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
@@ -27,11 +26,10 @@ type Source interface {
GetCookieOptsPathAware(r *http.Request) cookie.Options
GetCrypter() crypto.Crypter
GetErrorHandler() errorhandler.Handler
GetLoginstatus() *loginstatus.Loginstatus
}
func Handler(src Source, w http.ResponseWriter, r *http.Request) {
login, err := src.GetClient().Login(r, src.GetLoginstatus())
login, err := src.GetClient().Login(r)
if err != nil {
if errors.Is(err, openidclient.InvalidSecurityLevelError) || errors.Is(err, openidclient.InvalidLocaleError) {
src.GetErrorHandler().BadRequest(w, r, err)

View File

@@ -29,7 +29,6 @@ type Source interface {
GetCrypter() crypto.Crypter
GetErrorHandler() errorhandler.Handler
GetLoginstatus() *loginstatus.Loginstatus
GetProvider() openidclient.OpenIDProvider
GetSessions() *session.Handler
GetSessionConfig() config.Session
}
@@ -48,7 +47,7 @@ func Handler(src Source, w http.ResponseWriter, r *http.Request) {
return
}
loginCallback, err := src.GetClient().LoginCallback(r, src.GetProvider(), loginCookie)
loginCallback, err := src.GetClient().LoginCallback(r, loginCookie)
if err != nil {
src.GetErrorHandler().InternalError(w, r, err)
return

View File

@@ -19,8 +19,8 @@ import (
func NewHandler(
cfg *config.Config,
cookieOpts cookie.Options,
jwksProvider client.JwksProvider,
openidConfig openidconfig.Config,
openidProvider client.OpenIDProvider,
crypter crypto.Crypter,
) (*StandardHandler, error) {
autoLogin, err := autologin.New(cfg)
@@ -32,7 +32,9 @@ func NewHandler(
Timeout: time.Second * 10,
}
openidClient := client.NewClient(openidConfig)
loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, httpClient)
openidClient := client.NewClient(openidConfig, loginstatusClient, jwksProvider)
openidClient.SetHttpClient(httpClient)
sessionHandler, err := session.NewHandler(cfg, openidConfig, crypter, openidClient)
@@ -52,9 +54,8 @@ func NewHandler(
cookieOptions: cookieOpts,
crypter: crypter,
ingresses: ingresses,
loginstatus: loginstatus.NewClient(cfg.Loginstatus, httpClient),
loginstatus: loginstatusClient,
openidConfig: openidConfig,
provider: openidProvider,
sessions: sessionHandler,
upstreamProxy: reverseproxy.New(cfg.UpstreamHost),
}, nil

View File

@@ -36,7 +36,6 @@ type StandardHandler struct {
ingresses *ingress.Ingresses
loginstatus *loginstatus.Loginstatus
openidConfig openidconfig.Config
provider openidclient.OpenIDProvider
sessions *session.Handler
upstreamProxy *reverseproxy.ReverseProxy
}
@@ -91,10 +90,6 @@ func (s *StandardHandler) GetPath(r *http.Request) string {
return path
}
func (s *StandardHandler) GetProvider() openidclient.OpenIDProvider {
return s.provider
}
func (s *StandardHandler) GetProviderName() string {
return s.openidConfig.Provider().Name()
}

View File

@@ -31,7 +31,6 @@ import (
type IdentityProvider struct {
Cfg *config.Config
OpenIDConfig *TestConfiguration
Provider *TestProvider
ProviderHandler *IdentityProviderHandler
ProviderServer *httptest.Server
RelyingPartyHandler *handlerpkg.StandardHandler
@@ -75,8 +74,8 @@ func (in *IdentityProvider) GetRequest(target string) *http.Request {
func NewIdentityProvider(cfg *config.Config) *IdentityProvider {
openidConfig := NewTestConfiguration(cfg)
provider := newTestProvider()
handler := newIdentityProviderHandler(provider, openidConfig)
jwksProvider := NewTestJwksProvider()
handler := newIdentityProviderHandler(jwksProvider, openidConfig)
idpRouter := identityProviderRouter(handler)
server := httptest.NewServer(idpRouter)
@@ -89,7 +88,7 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider {
crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey))
cookieOpts := cookie.DefaultOptions().WithSecure(false)
rpHandler, err := handlerpkg.NewHandler(cfg, cookieOpts, openidConfig, provider, crypter)
rpHandler, err := handlerpkg.NewHandler(cfg, cookieOpts, jwksProvider, openidConfig, crypter)
if err != nil {
panic(err)
}
@@ -102,7 +101,6 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider {
RelyingPartyHandler: rpHandler,
RelyingPartyServer: rpServer,
OpenIDConfig: openidConfig,
Provider: provider,
ProviderHandler: handler,
ProviderServer: server,
}

View File

@@ -28,7 +28,7 @@ func (p *TestProvider) PrivateJwkSet() *jwk.Set {
return &p.JwksPair.Private
}
func newTestProvider() *TestProvider {
func NewTestJwksProvider() *TestProvider {
jwksPair, err := crypto.NewJwkSet()
if err != nil {
log.Fatal(err)

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt"
"golang.org/x/oauth2"
@@ -19,13 +20,20 @@ import (
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
)
type JwksProvider interface {
GetPublicJwkSet(ctx context.Context) (*jwk.Set, error)
RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error)
}
type Client struct {
cfg openidconfig.Config
httpClient *http.Client
jwksProvider JwksProvider
loginstatus *loginstatus.Loginstatus
oauth2Config *oauth2.Config
}
func NewClient(cfg openidconfig.Config) *Client {
func NewClient(cfg openidconfig.Config, loginstatus *loginstatus.Loginstatus, jwksProvider JwksProvider) *Client {
oauth2Config := &oauth2.Config{
ClientID: cfg.Client().ClientID(),
Endpoint: oauth2.Endpoint{
@@ -39,24 +47,18 @@ func NewClient(cfg openidconfig.Config) *Client {
return &Client{
cfg: cfg,
httpClient: http.DefaultClient,
jwksProvider: jwksProvider,
loginstatus: loginstatus,
oauth2Config: oauth2Config,
}
}
func (c *Client) config() openidconfig.Config {
return c.cfg
}
func (c *Client) oAuth2Config() *oauth2.Config {
return c.oauth2Config
}
func (c *Client) SetHttpClient(httpClient *http.Client) {
c.httpClient = httpClient
}
func (c *Client) Login(r *http.Request, loginstatus *loginstatus.Loginstatus) (*Login, error) {
login, err := NewLogin(c, r, loginstatus)
func (c *Client) Login(r *http.Request) (*Login, error) {
login, err := NewLogin(c, r)
if err != nil {
return nil, fmt.Errorf("login: %w", err)
}
@@ -64,8 +66,8 @@ func (c *Client) Login(r *http.Request, loginstatus *loginstatus.Loginstatus) (*
return login, nil
}
func (c *Client) LoginCallback(r *http.Request, p OpenIDProvider, cookie *openid.LoginCookie) (*LoginCallback, error) {
loginCallback, err := NewLoginCallback(c, r, p, cookie)
func (c *Client) LoginCallback(r *http.Request, cookie *openid.LoginCookie) (*LoginCallback, error) {
loginCallback, err := NewLoginCallback(c, r, cookie)
if err != nil {
return nil, fmt.Errorf("callback: %w", err)
}
@@ -95,8 +97,8 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A
}
func (c *Client) MakeAssertion(expiration time.Duration) (string, error) {
clientCfg := c.config().Client()
providerCfg := c.config().Provider()
clientCfg := c.cfg.Client()
providerCfg := c.cfg.Provider()
key := clientCfg.ClientJWK()
iat := time.Now().Truncate(time.Second)
@@ -135,11 +137,11 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid
v := url.Values{}
v.Set(openid.GrantType, openid.RefreshTokenValue)
v.Set(openid.RefreshToken, refreshToken)
v.Set(openid.ClientID, c.config().Client().ClientID())
v.Set(openid.ClientID, c.cfg.Client().ClientID())
v.Set(openid.ClientAssertion, assertion)
v.Set(openid.ClientAssertionType, openid.ClientAssertionTypeJwtBearer)
r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.config().Provider().TokenEndpoint(), strings.NewReader(v.Encode()))
r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), strings.NewReader(v.Encode()))
if err != nil {
return nil, fmt.Errorf("creating request: %w", err)
}

View File

@@ -17,7 +17,7 @@ func TestMakeAssertion(t *testing.T) {
openidConfig := mock.NewTestConfiguration(cfg)
openidConfig.TestProvider.SetIssuer("some-issuer")
c := client.NewClient(openidConfig)
c := newTestClientWithConfig(openidConfig)
expiry := 30 * time.Second
assertionString, err := c.MakeAssertion(expiry)
@@ -46,7 +46,8 @@ func TestMakeAssertion(t *testing.T) {
}
func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client {
return client.NewClient(config)
jwksProvider := mock.NewTestJwksProvider()
return client.NewClient(config, nil, jwksProvider)
}
func newTestClient() *client.Client {

View File

@@ -37,7 +37,7 @@ var (
}
)
func NewLogin(c *Client, r *http.Request, loginstatus *loginstatus.Loginstatus) (*Login, error) {
func NewLogin(c *Client, r *http.Request) (*Login, error) {
params, err := newLoginParameters(c)
if err != nil {
return nil, fmt.Errorf("generating parameters: %w", err)
@@ -48,7 +48,7 @@ func NewLogin(c *Client, r *http.Request, loginstatus *loginstatus.Loginstatus)
return nil, fmt.Errorf("generating callback url: %w", err)
}
url, err := params.authCodeURL(r, callbackURL, loginstatus)
url, err := params.authCodeURL(r, callbackURL, c.loginstatus)
if err != nil {
return nil, fmt.Errorf("generating auth code url: %w", err)
}
@@ -156,7 +156,7 @@ func (in *loginParameters) authCodeURL(r *http.Request, callbackURL string, logi
return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err)
}
authCodeUrl := in.oAuth2Config().AuthCodeURL(in.State, opts...)
authCodeUrl := in.oauth2Config.AuthCodeURL(in.State, opts...)
return authCodeUrl, nil
}
@@ -174,8 +174,8 @@ func (in *loginParameters) withLocale(r *http.Request, opts []oauth2.AuthCodeOpt
return withParamMapping(r,
opts,
LocaleURLParameter,
in.config().Client().UILocales(),
in.config().Provider().UILocalesSupported(),
in.cfg.Client().UILocales(),
in.cfg.Provider().UILocalesSupported(),
)
}
@@ -183,8 +183,8 @@ func (in *loginParameters) withSecurityLevel(r *http.Request, opts []oauth2.Auth
return withParamMapping(r,
opts,
SecurityLevelURLParameter,
in.config().Client().ACRValues(),
in.config().Provider().ACRValuesSupported(),
in.cfg.Client().ACRValues(),
in.cfg.Provider().ACRValuesSupported(),
)
}

View File

@@ -7,34 +7,26 @@ import (
"net/url"
"time"
"github.com/lestrrat-go/jwx/v2/jwk"
"golang.org/x/oauth2"
"github.com/nais/wonderwall/pkg/openid"
)
type OpenIDProvider interface {
GetPublicJwkSet(ctx context.Context) (*jwk.Set, error)
RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error)
}
type LoginCallback struct {
client *Client
*Client
cookie *openid.LoginCookie
provider OpenIDProvider
request *http.Request
requestParams url.Values
}
func NewLoginCallback(c *Client, r *http.Request, p OpenIDProvider, cookie *openid.LoginCookie) (*LoginCallback, error) {
func NewLoginCallback(c *Client, r *http.Request, cookie *openid.LoginCookie) (*LoginCallback, error) {
if cookie == nil {
return nil, fmt.Errorf("cookie is nil")
}
return &LoginCallback{
client: c,
Client: c,
cookie: cookie,
provider: p,
request: r,
requestParams: r.URL.Query(),
}, nil
@@ -66,7 +58,7 @@ func (in *LoginCallback) StateMismatchError() error {
}
func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) {
clientAssertion, err := in.client.MakeAssertion(time.Second * 30)
clientAssertion, err := in.MakeAssertion(time.Second * 30)
if err != nil {
return nil, fmt.Errorf("creating client assertion: %w", err)
}
@@ -79,12 +71,12 @@ func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, erro
}
code := in.requestParams.Get(openid.Code)
rawTokens, err := in.client.AuthCodeGrant(ctx, code, opts)
rawTokens, err := in.AuthCodeGrant(ctx, code, opts)
if err != nil {
return nil, fmt.Errorf("exchanging authorization code for token: %w", err)
}
jwkSet, err := in.provider.GetPublicJwkSet(ctx)
jwkSet, err := in.jwksProvider.GetPublicJwkSet(ctx)
if err != nil {
return nil, fmt.Errorf("getting jwks: %w", err)
}
@@ -92,11 +84,11 @@ func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, erro
tokens, err := openid.NewTokens(rawTokens, *jwkSet)
if err != nil {
// JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt
_, _ = in.provider.RefreshPublicJwkSet(ctx)
_, _ = in.jwksProvider.RefreshPublicJwkSet(ctx)
return nil, fmt.Errorf("parsing tokens: %w", err)
}
err = tokens.IDToken.Validate(in.client.config(), in.cookie.Nonce)
err = tokens.IDToken.Validate(in.cfg, in.cookie.Nonce)
if err != nil {
return nil, fmt.Errorf("validating id_token: %w", err)
}

View File

@@ -115,8 +115,6 @@ func newLoginCallback(t *testing.T, url string) (*mock.IdentityProvider, *client
idp.SetIngresses(mock.Ingress)
req := idp.GetRequest(url)
cfg := idp.OpenIDConfig
redirect, err := urlpkg.LoginCallbackURL(req)
assert.NoError(t, err)
@@ -136,7 +134,7 @@ func newLoginCallback(t *testing.T, url string) (*mock.IdentityProvider, *client
RedirectURI: redirect,
}
loginCallback, err := newTestClientWithConfig(cfg).LoginCallback(req, idp.Provider, cookie)
loginCallback, err := idp.RelyingPartyHandler.GetClient().LoginCallback(req, cookie)
assert.NoError(t, err)
return idp, loginCallback

View File

@@ -66,11 +66,11 @@ func TestLogin_URL(t *testing.T) {
openidConfig := mock.NewTestConfiguration(cfg)
ingresses := mock.Ingresses(cfg)
c := client.NewClient(openidConfig)
lsc := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient)
c := client.NewClient(openidConfig, lsc, nil)
req := mock.NewGetRequest(test.url, ingresses)
result, err := c.Login(req, lsc)
result, err := c.Login(req)
if test.error != nil {
assert.True(t, errors.Is(err, test.error))
@@ -126,12 +126,12 @@ func TestLoginURL_WithResourceIndicator(t *testing.T) {
openidConfig := mock.NewTestConfiguration(cfg)
openidConfig.TestProvider.SetAuthorizationEndpoint("https://provider/authorize")
c := client.NewClient(openidConfig)
c := client.NewClient(openidConfig, lsc, nil)
ingresses := mock.Ingresses(cfg)
req := mock.NewGetRequest(mock.Ingress+"/oauth2/login", ingresses)
result, err := c.Login(req, lsc)
result, err := c.Login(req)
assert.NoError(t, err)
assert.NotEmpty(t, result)
parsed, err := url.Parse(result.AuthCodeURL())

View File

@@ -28,7 +28,7 @@ func NewLogout(c *Client, r *http.Request) (*Logout, error) {
}
func (in *Logout) SingleLogoutURL(idToken string) string {
endSessionEndpoint := in.config().Provider().EndSessionEndpointURL()
endSessionEndpoint := in.cfg.Provider().EndSessionEndpointURL()
v := endSessionEndpoint.Query()
v.Add(openid.PostLogoutRedirectURI, in.logoutCallbackURL)

View File

@@ -19,7 +19,7 @@ func NewLogoutCallback(c *Client, r *http.Request) *LogoutCallback {
}
func (in *LogoutCallback) PostLogoutRedirectURI() string {
redirect := in.config().Client().PostLogoutRedirectURI()
redirect := in.cfg.Client().PostLogoutRedirectURI()
if len(redirect) > 0 {
return redirect

View File

@@ -15,7 +15,7 @@ const (
JwkMinimumRefreshInterval = 5 * time.Second
)
type Provider struct {
type JwksProvider struct {
config openidconfig.Provider
jwksCache *jwk.Cache
jwksLock *jwksLock
@@ -26,7 +26,7 @@ type jwksLock struct {
sync.Mutex
}
func (p *Provider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
func (p *JwksProvider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
url := p.config.JwksURI()
set, err := p.jwksCache.Get(ctx, url)
if err != nil {
@@ -36,7 +36,7 @@ func (p *Provider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
return &set, nil
}
func (p *Provider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
func (p *JwksProvider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
p.jwksLock.Lock()
defer p.jwksLock.Unlock()
@@ -57,7 +57,7 @@ func (p *Provider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) {
return &set, nil
}
func NewProvider(ctx context.Context, openidCfg openidconfig.Config) (*Provider, error) {
func NewJwksProvider(ctx context.Context, openidCfg openidconfig.Config) (*JwksProvider, error) {
providerCfg := openidCfg.Provider()
uri := providerCfg.JwksURI()
@@ -74,7 +74,7 @@ func NewProvider(ctx context.Context, openidCfg openidconfig.Config) (*Provider,
return nil, fmt.Errorf("initial fetch of jwks from provider: %w", err)
}
return &Provider{
return &JwksProvider{
config: providerCfg,
jwksCache: cache,
jwksLock: &jwksLock{},