From a4e4fc752eaa79c67c29a22a87b3952caf2c6a79 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Tue, 31 Jan 2023 15:23:08 +0100 Subject: [PATCH] refactor(handler): remove provider name getter from handler --- cmd/wonderwall/main.go | 2 +- pkg/handler/handler.go | 20 ++++++------ pkg/handler/handler_sso_proxy.go | 5 --- pkg/handler/handler_sso_server.go | 5 --- pkg/handler/handler_standard.go | 54 +++++++++++++------------------ pkg/mock/openid.go | 18 +++-------- pkg/router/router.go | 9 +++--- 7 files changed, 43 insertions(+), 70 deletions(-) diff --git a/cmd/wonderwall/main.go b/cmd/wonderwall/main.go index 5d7fbe5..b0c33d2 100644 --- a/cmd/wonderwall/main.go +++ b/cmd/wonderwall/main.go @@ -59,7 +59,7 @@ func run() error { if err != nil { return fmt.Errorf("initializing routing handler: %w", err) } - r := router.New(h) + r := router.New(h, cfg) go func() { err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider) diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 1941752..e2b93b8 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -47,15 +47,15 @@ func NewHandler( } return &StandardHandler{ - autoLogin: autoLogin, - client: openidClient, - config: cfg, - cookieOptions: cookieOpts, - crypter: crypter, - ingresses: ingresses, - loginstatus: loginstatusClient, - openidConfig: openidConfig, - sessions: sessionHandler, - upstreamProxy: NewReverseProxy(cfg.UpstreamHost), + AutoLogin: autoLogin, + Client: openidClient, + Config: cfg, + CookieOptions: cookieOpts, + Crypter: crypter, + Ingresses: ingresses, + Loginstatus: loginstatusClient, + OpenidConfig: openidConfig, + Sessions: sessionHandler, + UpstreamProxy: NewReverseProxy(cfg.UpstreamHost), }, nil } diff --git a/pkg/handler/handler_sso_proxy.go b/pkg/handler/handler_sso_proxy.go index a09092f..4daf43c 100644 --- a/pkg/handler/handler_sso_proxy.go +++ b/pkg/handler/handler_sso_proxy.go @@ -61,8 +61,3 @@ func (s *SSOProxyHandler) GetIngresses() *ingress.Ingresses { //TODO implement me panic("implement me") } - -func (s *SSOProxyHandler) GetProviderName() string { - //TODO implement me - panic("implement me") -} diff --git a/pkg/handler/handler_sso_server.go b/pkg/handler/handler_sso_server.go index ed41c38..bab5a0f 100644 --- a/pkg/handler/handler_sso_server.go +++ b/pkg/handler/handler_sso_server.go @@ -61,8 +61,3 @@ func (s *SSOServerHandler) GetIngresses() *ingress.Ingresses { //TODO implement me panic("implement me") } - -func (s *SSOServerHandler) GetProviderName() string { - //TODO implement me - panic("implement me") -} diff --git a/pkg/handler/handler_standard.go b/pkg/handler/handler_standard.go index 8bd103f..c91a258 100644 --- a/pkg/handler/handler_standard.go +++ b/pkg/handler/handler_standard.go @@ -20,37 +20,37 @@ import ( var _ router.Source = &StandardHandler{} type StandardHandler struct { - autoLogin *autologin.AutoLogin - client *openidclient.Client - config *config.Config - cookieOptions cookie.Options - crypter crypto.Crypter - ingresses *ingress.Ingresses - loginstatus *loginstatus.Loginstatus - openidConfig openidconfig.Config - sessions *sessionStore.Handler - upstreamProxy *ReverseProxy + AutoLogin *autologin.AutoLogin + Client *openidclient.Client + Config *config.Config + CookieOptions cookie.Options + Crypter crypto.Crypter + Ingresses *ingress.Ingresses + Loginstatus *loginstatus.Loginstatus + OpenidConfig openidconfig.Config + Sessions *sessionStore.Handler + UpstreamProxy *ReverseProxy } func (s *StandardHandler) GetAutoLogin() *autologin.AutoLogin { - return s.autoLogin + return s.AutoLogin } func (s *StandardHandler) GetClient() *openidclient.Client { - return s.client + return s.Client } func (s *StandardHandler) GetCookieOptions() cookie.Options { - return s.cookieOptions + return s.CookieOptions } func (s *StandardHandler) GetCookieOptsPathAware(r *http.Request) cookie.Options { path := s.GetPath(r) - return s.cookieOptions.WithPath(path) + return s.CookieOptions.WithPath(path) } func (s *StandardHandler) GetCrypter() crypto.Crypter { - return s.crypter + return s.Crypter } func (s *StandardHandler) GetErrorHandler() errorhandler.Handler { @@ -58,40 +58,32 @@ func (s *StandardHandler) GetErrorHandler() errorhandler.Handler { } func (s *StandardHandler) GetErrorPath() string { - return s.config.ErrorPath + return s.Config.ErrorPath } func (s *StandardHandler) GetIngresses() *ingress.Ingresses { - return s.ingresses -} - -func (s *StandardHandler) SetIngresses(ingresses *ingress.Ingresses) { - s.ingresses = ingresses + return s.Ingresses } func (s *StandardHandler) GetLoginstatus() *loginstatus.Loginstatus { - return s.loginstatus + return s.Loginstatus } func (s *StandardHandler) GetPath(r *http.Request) string { path, ok := middleware.PathFrom(r.Context()) if !ok { - path = s.GetIngresses().MatchingPath(r) + path = s.Ingresses.MatchingPath(r) } return path } -func (s *StandardHandler) GetProviderName() string { - return string(s.config.OpenID.Provider) -} - func (s *StandardHandler) GetSessions() *sessionStore.Handler { - return s.sessions + return s.Sessions } func (s *StandardHandler) GetSessionConfig() config.Session { - return s.config.Session + return s.Config.Session } func (s *StandardHandler) Login(w http.ResponseWriter, r *http.Request) { @@ -129,7 +121,7 @@ func (s *StandardHandler) Session(w http.ResponseWriter, r *http.Request) { } func (s *StandardHandler) SessionRefresh(w http.ResponseWriter, r *http.Request) { - if !s.config.Session.Refresh { + if !s.Config.Session.Refresh { http.NotFound(w, r) return } @@ -138,5 +130,5 @@ func (s *StandardHandler) SessionRefresh(w http.ResponseWriter, r *http.Request) } func (s *StandardHandler) ReverseProxy(w http.ResponseWriter, r *http.Request) { - s.upstreamProxy.Handler(s, w, r) + s.UpstreamProxy.Handler(s, w, r) } diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index f1669de..0507e5b 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -20,31 +20,20 @@ import ( "github.com/nais/wonderwall/pkg/cookie" "github.com/nais/wonderwall/pkg/crypto" handlerpkg "github.com/nais/wonderwall/pkg/handler" - errorhandler "github.com/nais/wonderwall/pkg/handler/error" "github.com/nais/wonderwall/pkg/ingress" "github.com/nais/wonderwall/pkg/openid" openidclient "github.com/nais/wonderwall/pkg/openid/client" openidconfig "github.com/nais/wonderwall/pkg/openid/config" scopespkg "github.com/nais/wonderwall/pkg/openid/scopes" "github.com/nais/wonderwall/pkg/router" - "github.com/nais/wonderwall/pkg/session" ) -type RelyingPartyHandler interface { - router.Source - GetClient() *openidclient.Client - GetCrypter() crypto.Crypter - GetErrorHandler() errorhandler.Handler - GetSessions() *session.Handler - SetIngresses(ingresses *ingress.Ingresses) -} - type IdentityProvider struct { Cfg *config.Config OpenIDConfig *TestConfiguration ProviderHandler *IdentityProviderHandler ProviderServer *httptest.Server - RelyingPartyHandler RelyingPartyHandler + RelyingPartyHandler *handlerpkg.StandardHandler RelyingPartyServer *httptest.Server } @@ -76,7 +65,7 @@ func (in *IdentityProvider) SetIngresses(ingresses ...string) { panic(err) } - in.RelyingPartyHandler.SetIngresses(parsed) + in.RelyingPartyHandler.Ingresses = parsed } func (in *IdentityProvider) GetRequest(target string) *http.Request { @@ -99,12 +88,13 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider { crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey)) cookieOpts := cookie.DefaultOptions().WithSecure(false) + rpHandler, err := handlerpkg.NewHandler(cfg, cookieOpts, jwksProvider, openidConfig, crypter) if err != nil { panic(err) } - rpRouter := router.New(rpHandler) + rpRouter := router.New(rpHandler, cfg) rpServer := httptest.NewServer(rpRouter) ip := &IdentityProvider{ diff --git a/pkg/router/router.go b/pkg/router/router.go index 6782adb..a3b489e 100644 --- a/pkg/router/router.go +++ b/pkg/router/router.go @@ -6,6 +6,7 @@ import ( "github.com/go-chi/chi/v5" chi_middleware "github.com/go-chi/chi/v5/middleware" + "github.com/nais/wonderwall/pkg/config" "github.com/nais/wonderwall/pkg/ingress" "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/router/paths" @@ -39,13 +40,13 @@ type Handlers interface { type Config interface { GetIngresses() *ingress.Ingresses - GetProviderName() string } -func New(src Source) chi.Router { +func New(src Source, cfg *config.Config) chi.Router { + providerName := string(cfg.OpenID.Provider) ingressMw := middleware.Ingress(src) - prometheus := middleware.Prometheus(src.GetProviderName()) - logentry := middleware.LogEntry(src.GetProviderName()) + prometheus := middleware.Prometheus(providerName) + logentry := middleware.LogEntry(providerName) r := chi.NewRouter() r.Use(middleware.CorrelationIDHandler)