diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index b1f610e..79a38c3 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -40,12 +40,12 @@ func run() error { cookieOpts := cookie.DefaultOptions() - openidProvider, err := provider.NewProvider(ctx, openidConfig) + jwksProvider, err := provider.NewJwksProvider(ctx, openidConfig) if err != nil { return err } - h, err := handler.NewHandler(cfg, cookieOpts, openidConfig, openidProvider, crypt) + h, err := handler.NewHandler(cfg, cookieOpts, jwksProvider, openidConfig, crypt) if err != nil { return fmt.Errorf("initializing routing handler: %w", err) } diff --git a/pkg/handler/api/login/login.go b/pkg/handler/api/login/login.go index 7e7a841..db3c768 100644 --- a/pkg/handler/api/login/login.go +++ b/pkg/handler/api/login/login.go @@ -12,7 +12,6 @@ import ( "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/crypto" errorhandler "github.com/nais/wonderwall/pkg/handler/error" - "github.com/nais/wonderwall/pkg/loginstatus" logentry "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" openidclient "github.com/nais/wonderwall/pkg/openid/client" @@ -27,11 +26,10 @@ type Source interface { GetCookieOptsPathAware(r *http.Request) cookie.Options GetCrypter() crypto.Crypter GetErrorHandler() errorhandler.Handler - GetLoginstatus() *loginstatus.Loginstatus } func Handler(src Source, w http.ResponseWriter, r *http.Request) { - login, err := src.GetClient().Login(r, src.GetLoginstatus()) + login, err := src.GetClient().Login(r) if err != nil { if errors.Is(err, openidclient.InvalidSecurityLevelError) || errors.Is(err, openidclient.InvalidLocaleError) { src.GetErrorHandler().BadRequest(w, r, err) diff --git a/pkg/handler/api/logincallback/logincallback.go b/pkg/handler/api/logincallback/logincallback.go index 650ff61..0abce11 100644 --- a/pkg/handler/api/logincallback/logincallback.go +++ b/pkg/handler/api/logincallback/logincallback.go @@ -29,7 +29,6 @@ type Source interface { GetCrypter() crypto.Crypter GetErrorHandler() errorhandler.Handler GetLoginstatus() *loginstatus.Loginstatus - GetProvider() openidclient.OpenIDProvider GetSessions() *session.Handler GetSessionConfig() config.Session } @@ -48,7 +47,7 @@ func Handler(src Source, w http.ResponseWriter, r *http.Request) { return } - loginCallback, err := src.GetClient().LoginCallback(r, src.GetProvider(), loginCookie) + loginCallback, err := src.GetClient().LoginCallback(r, loginCookie) if err != nil { src.GetErrorHandler().InternalError(w, r, err) return diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 0919eb4..d6f3ade 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -19,8 +19,8 @@ import ( func NewHandler( cfg *config.Config, cookieOpts cookie.Options, + jwksProvider client.JwksProvider, openidConfig openidconfig.Config, - openidProvider client.OpenIDProvider, crypter crypto.Crypter, ) (*StandardHandler, error) { autoLogin, err := autologin.New(cfg) @@ -32,7 +32,9 @@ func NewHandler( Timeout: time.Second * 10, } - openidClient := client.NewClient(openidConfig) + loginstatusClient := loginstatus.NewClient(cfg.Loginstatus, httpClient) + + openidClient := client.NewClient(openidConfig, loginstatusClient, jwksProvider) openidClient.SetHttpClient(httpClient) sessionHandler, err := session.NewHandler(cfg, openidConfig, crypter, openidClient) @@ -52,9 +54,8 @@ func NewHandler( cookieOptions: cookieOpts, crypter: crypter, ingresses: ingresses, - loginstatus: loginstatus.NewClient(cfg.Loginstatus, httpClient), + loginstatus: loginstatusClient, openidConfig: openidConfig, - provider: openidProvider, sessions: sessionHandler, upstreamProxy: reverseproxy.New(cfg.UpstreamHost), }, nil diff --git a/pkg/handler/handler_standard.go b/pkg/handler/handler_standard.go index 2627185..43aa38d 100644 --- a/pkg/handler/handler_standard.go +++ b/pkg/handler/handler_standard.go @@ -36,7 +36,6 @@ type StandardHandler struct { ingresses *ingress.Ingresses loginstatus *loginstatus.Loginstatus openidConfig openidconfig.Config - provider openidclient.OpenIDProvider sessions *session.Handler upstreamProxy *reverseproxy.ReverseProxy } @@ -91,10 +90,6 @@ func (s *StandardHandler) GetPath(r *http.Request) string { return path } -func (s *StandardHandler) GetProvider() openidclient.OpenIDProvider { - return s.provider -} - func (s *StandardHandler) GetProviderName() string { return s.openidConfig.Provider().Name() } diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index 2755117..4f30d21 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -31,7 +31,6 @@ import ( type IdentityProvider struct { Cfg *config.Config OpenIDConfig *TestConfiguration - Provider *TestProvider ProviderHandler *IdentityProviderHandler ProviderServer *httptest.Server RelyingPartyHandler *handlerpkg.StandardHandler @@ -75,8 +74,8 @@ func (in *IdentityProvider) GetRequest(target string) *http.Request { func NewIdentityProvider(cfg *config.Config) *IdentityProvider { openidConfig := NewTestConfiguration(cfg) - provider := newTestProvider() - handler := newIdentityProviderHandler(provider, openidConfig) + jwksProvider := NewTestJwksProvider() + handler := newIdentityProviderHandler(jwksProvider, openidConfig) idpRouter := identityProviderRouter(handler) server := httptest.NewServer(idpRouter) @@ -89,7 +88,7 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider { crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey)) cookieOpts := cookie.DefaultOptions().WithSecure(false) - rpHandler, err := handlerpkg.NewHandler(cfg, cookieOpts, openidConfig, provider, crypter) + rpHandler, err := handlerpkg.NewHandler(cfg, cookieOpts, jwksProvider, openidConfig, crypter) if err != nil { panic(err) } @@ -102,7 +101,6 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider { RelyingPartyHandler: rpHandler, RelyingPartyServer: rpServer, OpenIDConfig: openidConfig, - Provider: provider, ProviderHandler: handler, ProviderServer: server, } diff --git a/pkg/mock/provider.go b/pkg/mock/provider.go index 3e17f1c..5ffe832 100644 --- a/pkg/mock/provider.go +++ b/pkg/mock/provider.go @@ -28,7 +28,7 @@ func (p *TestProvider) PrivateJwkSet() *jwk.Set { return &p.JwksPair.Private } -func newTestProvider() *TestProvider { +func NewTestJwksProvider() *TestProvider { jwksPair, err := crypto.NewJwkSet() if err != nil { log.Fatal(err) diff --git a/pkg/openid/client/client.go b/pkg/openid/client/client.go index 84b13be..f3c7596 100644 --- a/pkg/openid/client/client.go +++ b/pkg/openid/client/client.go @@ -11,6 +11,7 @@ import ( "time" "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "golang.org/x/oauth2" @@ -19,13 +20,20 @@ import ( openidconfig "github.com/nais/wonderwall/pkg/openid/config" ) +type JwksProvider interface { + GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) + RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) +} + type Client struct { cfg openidconfig.Config httpClient *http.Client + jwksProvider JwksProvider + loginstatus *loginstatus.Loginstatus oauth2Config *oauth2.Config } -func NewClient(cfg openidconfig.Config) *Client { +func NewClient(cfg openidconfig.Config, loginstatus *loginstatus.Loginstatus, jwksProvider JwksProvider) *Client { oauth2Config := &oauth2.Config{ ClientID: cfg.Client().ClientID(), Endpoint: oauth2.Endpoint{ @@ -39,24 +47,18 @@ func NewClient(cfg openidconfig.Config) *Client { return &Client{ cfg: cfg, httpClient: http.DefaultClient, + jwksProvider: jwksProvider, + loginstatus: loginstatus, oauth2Config: oauth2Config, } } -func (c *Client) config() openidconfig.Config { - return c.cfg -} - -func (c *Client) oAuth2Config() *oauth2.Config { - return c.oauth2Config -} - func (c *Client) SetHttpClient(httpClient *http.Client) { c.httpClient = httpClient } -func (c *Client) Login(r *http.Request, loginstatus *loginstatus.Loginstatus) (*Login, error) { - login, err := NewLogin(c, r, loginstatus) +func (c *Client) Login(r *http.Request) (*Login, error) { + login, err := NewLogin(c, r) if err != nil { return nil, fmt.Errorf("login: %w", err) } @@ -64,8 +66,8 @@ func (c *Client) Login(r *http.Request, loginstatus *loginstatus.Loginstatus) (* return login, nil } -func (c *Client) LoginCallback(r *http.Request, p OpenIDProvider, cookie *openid.LoginCookie) (*LoginCallback, error) { - loginCallback, err := NewLoginCallback(c, r, p, cookie) +func (c *Client) LoginCallback(r *http.Request, cookie *openid.LoginCookie) (*LoginCallback, error) { + loginCallback, err := NewLoginCallback(c, r, cookie) if err != nil { return nil, fmt.Errorf("callback: %w", err) } @@ -95,8 +97,8 @@ func (c *Client) AuthCodeGrant(ctx context.Context, code string, opts []oauth2.A } func (c *Client) MakeAssertion(expiration time.Duration) (string, error) { - clientCfg := c.config().Client() - providerCfg := c.config().Provider() + clientCfg := c.cfg.Client() + providerCfg := c.cfg.Provider() key := clientCfg.ClientJWK() iat := time.Now().Truncate(time.Second) @@ -135,11 +137,11 @@ func (c *Client) RefreshGrant(ctx context.Context, refreshToken string) (*openid v := url.Values{} v.Set(openid.GrantType, openid.RefreshTokenValue) v.Set(openid.RefreshToken, refreshToken) - v.Set(openid.ClientID, c.config().Client().ClientID()) + v.Set(openid.ClientID, c.cfg.Client().ClientID()) v.Set(openid.ClientAssertion, assertion) v.Set(openid.ClientAssertionType, openid.ClientAssertionTypeJwtBearer) - r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.config().Provider().TokenEndpoint(), strings.NewReader(v.Encode())) + r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().TokenEndpoint(), strings.NewReader(v.Encode())) if err != nil { return nil, fmt.Errorf("creating request: %w", err) } diff --git a/pkg/openid/client/client_test.go b/pkg/openid/client/client_test.go index 3a47dc6..e0b3a36 100644 --- a/pkg/openid/client/client_test.go +++ b/pkg/openid/client/client_test.go @@ -17,7 +17,7 @@ func TestMakeAssertion(t *testing.T) { openidConfig := mock.NewTestConfiguration(cfg) openidConfig.TestProvider.SetIssuer("some-issuer") - c := client.NewClient(openidConfig) + c := newTestClientWithConfig(openidConfig) expiry := 30 * time.Second assertionString, err := c.MakeAssertion(expiry) @@ -46,7 +46,8 @@ func TestMakeAssertion(t *testing.T) { } func newTestClientWithConfig(config *mock.TestConfiguration) *client.Client { - return client.NewClient(config) + jwksProvider := mock.NewTestJwksProvider() + return client.NewClient(config, nil, jwksProvider) } func newTestClient() *client.Client { diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index d2467d6..d3ea205 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -37,7 +37,7 @@ var ( } ) -func NewLogin(c *Client, r *http.Request, loginstatus *loginstatus.Loginstatus) (*Login, error) { +func NewLogin(c *Client, r *http.Request) (*Login, error) { params, err := newLoginParameters(c) if err != nil { return nil, fmt.Errorf("generating parameters: %w", err) @@ -48,7 +48,7 @@ func NewLogin(c *Client, r *http.Request, loginstatus *loginstatus.Loginstatus) return nil, fmt.Errorf("generating callback url: %w", err) } - url, err := params.authCodeURL(r, callbackURL, loginstatus) + url, err := params.authCodeURL(r, callbackURL, c.loginstatus) if err != nil { return nil, fmt.Errorf("generating auth code url: %w", err) } @@ -156,7 +156,7 @@ func (in *loginParameters) authCodeURL(r *http.Request, callbackURL string, logi return "", fmt.Errorf("%w: %+v", InvalidLocaleError, err) } - authCodeUrl := in.oAuth2Config().AuthCodeURL(in.State, opts...) + authCodeUrl := in.oauth2Config.AuthCodeURL(in.State, opts...) return authCodeUrl, nil } @@ -174,8 +174,8 @@ func (in *loginParameters) withLocale(r *http.Request, opts []oauth2.AuthCodeOpt return withParamMapping(r, opts, LocaleURLParameter, - in.config().Client().UILocales(), - in.config().Provider().UILocalesSupported(), + in.cfg.Client().UILocales(), + in.cfg.Provider().UILocalesSupported(), ) } @@ -183,8 +183,8 @@ func (in *loginParameters) withSecurityLevel(r *http.Request, opts []oauth2.Auth return withParamMapping(r, opts, SecurityLevelURLParameter, - in.config().Client().ACRValues(), - in.config().Provider().ACRValuesSupported(), + in.cfg.Client().ACRValues(), + in.cfg.Provider().ACRValuesSupported(), ) } diff --git a/pkg/openid/client/login_callback.go b/pkg/openid/client/login_callback.go index f3ecd20..dd03b9d 100644 --- a/pkg/openid/client/login_callback.go +++ b/pkg/openid/client/login_callback.go @@ -7,34 +7,26 @@ import ( "net/url" "time" - "github.com/lestrrat-go/jwx/v2/jwk" "golang.org/x/oauth2" "github.com/nais/wonderwall/pkg/openid" ) -type OpenIDProvider interface { - GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) - RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) -} - type LoginCallback struct { - client *Client + *Client cookie *openid.LoginCookie - provider OpenIDProvider request *http.Request requestParams url.Values } -func NewLoginCallback(c *Client, r *http.Request, p OpenIDProvider, cookie *openid.LoginCookie) (*LoginCallback, error) { +func NewLoginCallback(c *Client, r *http.Request, cookie *openid.LoginCookie) (*LoginCallback, error) { if cookie == nil { return nil, fmt.Errorf("cookie is nil") } return &LoginCallback{ - client: c, + Client: c, cookie: cookie, - provider: p, request: r, requestParams: r.URL.Query(), }, nil @@ -66,7 +58,7 @@ func (in *LoginCallback) StateMismatchError() error { } func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, error) { - clientAssertion, err := in.client.MakeAssertion(time.Second * 30) + clientAssertion, err := in.MakeAssertion(time.Second * 30) if err != nil { return nil, fmt.Errorf("creating client assertion: %w", err) } @@ -79,12 +71,12 @@ func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, erro } code := in.requestParams.Get(openid.Code) - rawTokens, err := in.client.AuthCodeGrant(ctx, code, opts) + rawTokens, err := in.AuthCodeGrant(ctx, code, opts) if err != nil { return nil, fmt.Errorf("exchanging authorization code for token: %w", err) } - jwkSet, err := in.provider.GetPublicJwkSet(ctx) + jwkSet, err := in.jwksProvider.GetPublicJwkSet(ctx) if err != nil { return nil, fmt.Errorf("getting jwks: %w", err) } @@ -92,11 +84,11 @@ func (in *LoginCallback) RedeemTokens(ctx context.Context) (*openid.Tokens, erro tokens, err := openid.NewTokens(rawTokens, *jwkSet) if err != nil { // JWKS might not be up-to-date, so we'll want to force a refresh for the next attempt - _, _ = in.provider.RefreshPublicJwkSet(ctx) + _, _ = in.jwksProvider.RefreshPublicJwkSet(ctx) return nil, fmt.Errorf("parsing tokens: %w", err) } - err = tokens.IDToken.Validate(in.client.config(), in.cookie.Nonce) + err = tokens.IDToken.Validate(in.cfg, in.cookie.Nonce) if err != nil { return nil, fmt.Errorf("validating id_token: %w", err) } diff --git a/pkg/openid/client/login_callback_test.go b/pkg/openid/client/login_callback_test.go index 0bee0ca..22f88e3 100644 --- a/pkg/openid/client/login_callback_test.go +++ b/pkg/openid/client/login_callback_test.go @@ -115,8 +115,6 @@ func newLoginCallback(t *testing.T, url string) (*mock.IdentityProvider, *client idp.SetIngresses(mock.Ingress) req := idp.GetRequest(url) - cfg := idp.OpenIDConfig - redirect, err := urlpkg.LoginCallbackURL(req) assert.NoError(t, err) @@ -136,7 +134,7 @@ func newLoginCallback(t *testing.T, url string) (*mock.IdentityProvider, *client RedirectURI: redirect, } - loginCallback, err := newTestClientWithConfig(cfg).LoginCallback(req, idp.Provider, cookie) + loginCallback, err := idp.RelyingPartyHandler.GetClient().LoginCallback(req, cookie) assert.NoError(t, err) return idp, loginCallback diff --git a/pkg/openid/client/login_test.go b/pkg/openid/client/login_test.go index 91273cf..83c564f 100644 --- a/pkg/openid/client/login_test.go +++ b/pkg/openid/client/login_test.go @@ -66,11 +66,11 @@ func TestLogin_URL(t *testing.T) { openidConfig := mock.NewTestConfiguration(cfg) ingresses := mock.Ingresses(cfg) - c := client.NewClient(openidConfig) lsc := loginstatus.NewClient(cfg.Loginstatus, http.DefaultClient) + c := client.NewClient(openidConfig, lsc, nil) req := mock.NewGetRequest(test.url, ingresses) - result, err := c.Login(req, lsc) + result, err := c.Login(req) if test.error != nil { assert.True(t, errors.Is(err, test.error)) @@ -126,12 +126,12 @@ func TestLoginURL_WithResourceIndicator(t *testing.T) { openidConfig := mock.NewTestConfiguration(cfg) openidConfig.TestProvider.SetAuthorizationEndpoint("https://provider/authorize") - c := client.NewClient(openidConfig) + c := client.NewClient(openidConfig, lsc, nil) ingresses := mock.Ingresses(cfg) req := mock.NewGetRequest(mock.Ingress+"/oauth2/login", ingresses) - result, err := c.Login(req, lsc) + result, err := c.Login(req) assert.NoError(t, err) assert.NotEmpty(t, result) parsed, err := url.Parse(result.AuthCodeURL()) diff --git a/pkg/openid/client/logout.go b/pkg/openid/client/logout.go index ad93645..c3455d4 100644 --- a/pkg/openid/client/logout.go +++ b/pkg/openid/client/logout.go @@ -28,7 +28,7 @@ func NewLogout(c *Client, r *http.Request) (*Logout, error) { } func (in *Logout) SingleLogoutURL(idToken string) string { - endSessionEndpoint := in.config().Provider().EndSessionEndpointURL() + endSessionEndpoint := in.cfg.Provider().EndSessionEndpointURL() v := endSessionEndpoint.Query() v.Add(openid.PostLogoutRedirectURI, in.logoutCallbackURL) diff --git a/pkg/openid/client/logout_callback.go b/pkg/openid/client/logout_callback.go index 0353a85..73a541f 100644 --- a/pkg/openid/client/logout_callback.go +++ b/pkg/openid/client/logout_callback.go @@ -19,7 +19,7 @@ func NewLogoutCallback(c *Client, r *http.Request) *LogoutCallback { } func (in *LogoutCallback) PostLogoutRedirectURI() string { - redirect := in.config().Client().PostLogoutRedirectURI() + redirect := in.cfg.Client().PostLogoutRedirectURI() if len(redirect) > 0 { return redirect diff --git a/pkg/openid/provider/provider.go b/pkg/openid/provider/provider.go index 1f84aa0..193ae73 100644 --- a/pkg/openid/provider/provider.go +++ b/pkg/openid/provider/provider.go @@ -15,7 +15,7 @@ const ( JwkMinimumRefreshInterval = 5 * time.Second ) -type Provider struct { +type JwksProvider struct { config openidconfig.Provider jwksCache *jwk.Cache jwksLock *jwksLock @@ -26,7 +26,7 @@ type jwksLock struct { sync.Mutex } -func (p *Provider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) { +func (p *JwksProvider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) { url := p.config.JwksURI() set, err := p.jwksCache.Get(ctx, url) if err != nil { @@ -36,7 +36,7 @@ func (p *Provider) GetPublicJwkSet(ctx context.Context) (*jwk.Set, error) { return &set, nil } -func (p *Provider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) { +func (p *JwksProvider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) { p.jwksLock.Lock() defer p.jwksLock.Unlock() @@ -57,7 +57,7 @@ func (p *Provider) RefreshPublicJwkSet(ctx context.Context) (*jwk.Set, error) { return &set, nil } -func NewProvider(ctx context.Context, openidCfg openidconfig.Config) (*Provider, error) { +func NewJwksProvider(ctx context.Context, openidCfg openidconfig.Config) (*JwksProvider, error) { providerCfg := openidCfg.Provider() uri := providerCfg.JwksURI() @@ -74,7 +74,7 @@ func NewProvider(ctx context.Context, openidCfg openidconfig.Config) (*Provider, return nil, fmt.Errorf("initial fetch of jwks from provider: %w", err) } - return &Provider{ + return &JwksProvider{ config: providerCfg, jwksCache: cache, jwksLock: &jwksLock{},