From 0e73c9b4d8b52290ccaf5f8621d9bbde80c6c2ef Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Mon, 6 Feb 2023 10:51:02 +0100 Subject: [PATCH] refactor(mock): configure relying party ingress before server start --- pkg/handler/error/error_test.go | 13 ++++---- pkg/mock/openid.go | 40 ++++++++++++++---------- pkg/openid/client/login_callback_test.go | 1 - 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/pkg/handler/error/error_test.go b/pkg/handler/error/error_test.go index e1802c1..fe95b4a 100644 --- a/pkg/handler/error/error_test.go +++ b/pkg/handler/error/error_test.go @@ -71,12 +71,6 @@ func TestHandler_Error(t *testing.T) { } func TestHandler_Retry(t *testing.T) { - cfg := mock.Config() - idp := mock.NewIdentityProvider(cfg) - defer idp.Close() - - handler := idp.RelyingPartyHandler.GetErrorHandler() - get := func(url string) *http.Request { return httptest.NewRequest(http.MethodGet, url, nil) } @@ -189,7 +183,12 @@ func TestHandler_Retry(t *testing.T) { test.ingress = mock.Ingress } - idp.SetIngresses(test.ingress) + cfg := mock.Config() + cfg.Ingresses = []string{test.ingress} + idp := mock.NewIdentityProvider(cfg) + defer idp.Close() + + handler := idp.RelyingPartyHandler.GetErrorHandler() ing, err := ingress.ParseIngress(test.ingress) assert.NoError(t, err) diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index a702e0e..a52f5e0 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -20,7 +20,6 @@ import ( "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/crypto" handlerpkg "github.com/nais/wonderwall/pkg/handler" - "github.com/nais/wonderwall/pkg/ingress" "github.com/nais/wonderwall/pkg/openid" openidclient "github.com/nais/wonderwall/pkg/openid/client" openidconfig "github.com/nais/wonderwall/pkg/openid/config" @@ -57,22 +56,14 @@ func (in *IdentityProvider) RelyingPartyClient() *http.Client { return client } -func (in *IdentityProvider) SetIngresses(ingresses ...string) { - in.Cfg.Ingresses = ingresses - - parsed, err := ingress.ParseIngresses(in.Cfg) - if err != nil { - panic(err) - } - - in.RelyingPartyHandler.Ingresses = parsed -} - func (in *IdentityProvider) GetRequest(target string) *http.Request { return NewGetRequest(target, in.RelyingPartyHandler.GetIngresses()) } func NewIdentityProvider(cfg *config.Config) *IdentityProvider { + rpServer := newRelyingPartyServer() + cfg.Ingresses = append(cfg.Ingresses, rpServer.GetURL()) + openidConfig := NewTestConfiguration(cfg) jwksProvider := NewTestJwksProvider() handler := newIdentityProviderHandler(jwksProvider, openidConfig) @@ -95,19 +86,18 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider { } rpRouter := router.New(rpHandler, cfg) - rpServer := httptest.NewServer(rpRouter) + rpServer.SetHandler(rpRouter) + rpServer.Start() ip := &IdentityProvider{ Cfg: cfg, RelyingPartyHandler: rpHandler, - RelyingPartyServer: rpServer, + RelyingPartyServer: rpServer.Server, OpenIDConfig: openidConfig, ProviderHandler: handler, ProviderServer: server, } - // reconfigure ingresses after Relying Party server is started - ip.SetIngresses(rpServer.URL) return ip } @@ -572,3 +562,21 @@ func (ip *IdentityProviderHandler) EndSession(w http.ResponseWriter, r *http.Req http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) } + +type relyingPartyServer struct { + *httptest.Server +} + +func newRelyingPartyServer() *relyingPartyServer { + return &relyingPartyServer{httptest.NewUnstartedServer(nil)} +} + +func (in *relyingPartyServer) GetURL() string { + return "http://" + in.Listener.Addr().String() +} + +func (in *relyingPartyServer) SetHandler(handler http.Handler) { + in.Config = &http.Server{ + Handler: handler, + } +} diff --git a/pkg/openid/client/login_callback_test.go b/pkg/openid/client/login_callback_test.go index 22f88e3..b8a3d01 100644 --- a/pkg/openid/client/login_callback_test.go +++ b/pkg/openid/client/login_callback_test.go @@ -112,7 +112,6 @@ func TestLoginCallback_RedeemTokens(t *testing.T) { func newLoginCallback(t *testing.T, url string) (*mock.IdentityProvider, *client.LoginCallback) { idp := mock.NewIdentityProvider(mock.Config()) - idp.SetIngresses(mock.Ingress) req := idp.GetRequest(url) redirect, err := urlpkg.LoginCallbackURL(req)