refactor(handler): remove provider name getter from handler

This commit is contained in:
Trong Huu Nguyen
2023-01-31 15:23:08 +01:00
parent 3d08d0b4b0
commit a4e4fc752e
7 changed files with 43 additions and 70 deletions

View File

@@ -59,7 +59,7 @@ func run() error {
if err != nil {
return fmt.Errorf("initializing routing handler: %w", err)
}
r := router.New(h)
r := router.New(h, cfg)
go func() {
err := metrics.Handle(cfg.MetricsBindAddress, cfg.OpenID.Provider)

View File

@@ -47,15 +47,15 @@ func NewHandler(
}
return &StandardHandler{
autoLogin: autoLogin,
client: openidClient,
config: cfg,
cookieOptions: cookieOpts,
crypter: crypter,
ingresses: ingresses,
loginstatus: loginstatusClient,
openidConfig: openidConfig,
sessions: sessionHandler,
upstreamProxy: NewReverseProxy(cfg.UpstreamHost),
AutoLogin: autoLogin,
Client: openidClient,
Config: cfg,
CookieOptions: cookieOpts,
Crypter: crypter,
Ingresses: ingresses,
Loginstatus: loginstatusClient,
OpenidConfig: openidConfig,
Sessions: sessionHandler,
UpstreamProxy: NewReverseProxy(cfg.UpstreamHost),
}, nil
}

View File

@@ -61,8 +61,3 @@ func (s *SSOProxyHandler) GetIngresses() *ingress.Ingresses {
//TODO implement me
panic("implement me")
}
func (s *SSOProxyHandler) GetProviderName() string {
//TODO implement me
panic("implement me")
}

View File

@@ -61,8 +61,3 @@ func (s *SSOServerHandler) GetIngresses() *ingress.Ingresses {
//TODO implement me
panic("implement me")
}
func (s *SSOServerHandler) GetProviderName() string {
//TODO implement me
panic("implement me")
}

View File

@@ -20,37 +20,37 @@ import (
var _ router.Source = &StandardHandler{}
type StandardHandler struct {
autoLogin *autologin.AutoLogin
client *openidclient.Client
config *config.Config
cookieOptions cookie.Options
crypter crypto.Crypter
ingresses *ingress.Ingresses
loginstatus *loginstatus.Loginstatus
openidConfig openidconfig.Config
sessions *sessionStore.Handler
upstreamProxy *ReverseProxy
AutoLogin *autologin.AutoLogin
Client *openidclient.Client
Config *config.Config
CookieOptions cookie.Options
Crypter crypto.Crypter
Ingresses *ingress.Ingresses
Loginstatus *loginstatus.Loginstatus
OpenidConfig openidconfig.Config
Sessions *sessionStore.Handler
UpstreamProxy *ReverseProxy
}
func (s *StandardHandler) GetAutoLogin() *autologin.AutoLogin {
return s.autoLogin
return s.AutoLogin
}
func (s *StandardHandler) GetClient() *openidclient.Client {
return s.client
return s.Client
}
func (s *StandardHandler) GetCookieOptions() cookie.Options {
return s.cookieOptions
return s.CookieOptions
}
func (s *StandardHandler) GetCookieOptsPathAware(r *http.Request) cookie.Options {
path := s.GetPath(r)
return s.cookieOptions.WithPath(path)
return s.CookieOptions.WithPath(path)
}
func (s *StandardHandler) GetCrypter() crypto.Crypter {
return s.crypter
return s.Crypter
}
func (s *StandardHandler) GetErrorHandler() errorhandler.Handler {
@@ -58,40 +58,32 @@ func (s *StandardHandler) GetErrorHandler() errorhandler.Handler {
}
func (s *StandardHandler) GetErrorPath() string {
return s.config.ErrorPath
return s.Config.ErrorPath
}
func (s *StandardHandler) GetIngresses() *ingress.Ingresses {
return s.ingresses
}
func (s *StandardHandler) SetIngresses(ingresses *ingress.Ingresses) {
s.ingresses = ingresses
return s.Ingresses
}
func (s *StandardHandler) GetLoginstatus() *loginstatus.Loginstatus {
return s.loginstatus
return s.Loginstatus
}
func (s *StandardHandler) GetPath(r *http.Request) string {
path, ok := middleware.PathFrom(r.Context())
if !ok {
path = s.GetIngresses().MatchingPath(r)
path = s.Ingresses.MatchingPath(r)
}
return path
}
func (s *StandardHandler) GetProviderName() string {
return string(s.config.OpenID.Provider)
}
func (s *StandardHandler) GetSessions() *sessionStore.Handler {
return s.sessions
return s.Sessions
}
func (s *StandardHandler) GetSessionConfig() config.Session {
return s.config.Session
return s.Config.Session
}
func (s *StandardHandler) Login(w http.ResponseWriter, r *http.Request) {
@@ -129,7 +121,7 @@ func (s *StandardHandler) Session(w http.ResponseWriter, r *http.Request) {
}
func (s *StandardHandler) SessionRefresh(w http.ResponseWriter, r *http.Request) {
if !s.config.Session.Refresh {
if !s.Config.Session.Refresh {
http.NotFound(w, r)
return
}
@@ -138,5 +130,5 @@ func (s *StandardHandler) SessionRefresh(w http.ResponseWriter, r *http.Request)
}
func (s *StandardHandler) ReverseProxy(w http.ResponseWriter, r *http.Request) {
s.upstreamProxy.Handler(s, w, r)
s.UpstreamProxy.Handler(s, w, r)
}

View File

@@ -20,31 +20,20 @@ import (
"github.com/nais/wonderwall/pkg/cookie"
"github.com/nais/wonderwall/pkg/crypto"
handlerpkg "github.com/nais/wonderwall/pkg/handler"
errorhandler "github.com/nais/wonderwall/pkg/handler/error"
"github.com/nais/wonderwall/pkg/ingress"
"github.com/nais/wonderwall/pkg/openid"
openidclient "github.com/nais/wonderwall/pkg/openid/client"
openidconfig "github.com/nais/wonderwall/pkg/openid/config"
scopespkg "github.com/nais/wonderwall/pkg/openid/scopes"
"github.com/nais/wonderwall/pkg/router"
"github.com/nais/wonderwall/pkg/session"
)
type RelyingPartyHandler interface {
router.Source
GetClient() *openidclient.Client
GetCrypter() crypto.Crypter
GetErrorHandler() errorhandler.Handler
GetSessions() *session.Handler
SetIngresses(ingresses *ingress.Ingresses)
}
type IdentityProvider struct {
Cfg *config.Config
OpenIDConfig *TestConfiguration
ProviderHandler *IdentityProviderHandler
ProviderServer *httptest.Server
RelyingPartyHandler RelyingPartyHandler
RelyingPartyHandler *handlerpkg.StandardHandler
RelyingPartyServer *httptest.Server
}
@@ -76,7 +65,7 @@ func (in *IdentityProvider) SetIngresses(ingresses ...string) {
panic(err)
}
in.RelyingPartyHandler.SetIngresses(parsed)
in.RelyingPartyHandler.Ingresses = parsed
}
func (in *IdentityProvider) GetRequest(target string) *http.Request {
@@ -99,12 +88,13 @@ func NewIdentityProvider(cfg *config.Config) *IdentityProvider {
crypter := crypto.NewCrypter([]byte(cfg.EncryptionKey))
cookieOpts := cookie.DefaultOptions().WithSecure(false)
rpHandler, err := handlerpkg.NewHandler(cfg, cookieOpts, jwksProvider, openidConfig, crypter)
if err != nil {
panic(err)
}
rpRouter := router.New(rpHandler)
rpRouter := router.New(rpHandler, cfg)
rpServer := httptest.NewServer(rpRouter)
ip := &IdentityProvider{

View File

@@ -6,6 +6,7 @@ import (
"github.com/go-chi/chi/v5"
chi_middleware "github.com/go-chi/chi/v5/middleware"
"github.com/nais/wonderwall/pkg/config"
"github.com/nais/wonderwall/pkg/ingress"
"github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/router/paths"
@@ -39,13 +40,13 @@ type Handlers interface {
type Config interface {
GetIngresses() *ingress.Ingresses
GetProviderName() string
}
func New(src Source) chi.Router {
func New(src Source, cfg *config.Config) chi.Router {
providerName := string(cfg.OpenID.Provider)
ingressMw := middleware.Ingress(src)
prometheus := middleware.Prometheus(src.GetProviderName())
logentry := middleware.LogEntry(src.GetProviderName())
prometheus := middleware.Prometheus(providerName)
logentry := middleware.LogEntry(providerName)
r := chi.NewRouter()
r.Use(middleware.CorrelationIDHandler)