From 7c98fe161ec7cdf34ab58cf2745074908f5202bf Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Sat, 29 Apr 2023 11:00:01 +0200 Subject: [PATCH] refactor(handler/reverseproxy): retrieve both session and token --- pkg/handler/handler.go | 9 ++------- pkg/handler/handler_sso_proxy.go | 9 ++------- pkg/handler/reverseproxy.go | 20 +++++++++++++++++--- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 7f26329..4041e85 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -90,13 +90,8 @@ func NewStandalone( }, nil } -func (s *Standalone) GetAccessToken(r *http.Request) (string, error) { - sess, err := s.SessionManager.GetOrRefresh(r) - if err != nil { - return "", err - } - - return sess.AccessToken() +func (s *Standalone) GetSession(r *http.Request) (*session.Session, error) { + return s.SessionManager.GetOrRefresh(r) } func (s *Standalone) GetAutoLogin() *autologin.AutoLogin { diff --git a/pkg/handler/handler_sso_proxy.go b/pkg/handler/handler_sso_proxy.go index 391bc3b..60b1729 100644 --- a/pkg/handler/handler_sso_proxy.go +++ b/pkg/handler/handler_sso_proxy.go @@ -70,13 +70,8 @@ func NewSSOProxy(cfg *config.Config, crypter crypto.Crypter) (*SSOProxy, error) }, nil } -func (s *SSOProxy) GetAccessToken(r *http.Request) (string, error) { - sess, err := s.SessionReader.Get(r) - if err != nil { - return "", err - } - - return sess.AccessToken() +func (s *SSOProxy) GetSession(r *http.Request) (*session.Session, error) { + return s.SessionReader.Get(r) } func (s *SSOProxy) GetAutoLogin() *autologin.AutoLogin { diff --git a/pkg/handler/reverseproxy.go b/pkg/handler/reverseproxy.go index 8bbc6ba..c4ba204 100644 --- a/pkg/handler/reverseproxy.go +++ b/pkg/handler/reverseproxy.go @@ -17,9 +17,9 @@ import ( ) type ReverseProxySource interface { - GetAccessToken(r *http.Request) (string, error) GetAutoLogin() *autologin.AutoLogin GetPath(r *http.Request) string + GetSession(r *http.Request) (*session.Session, error) } type ReverseProxy struct { @@ -65,10 +65,10 @@ func (rp *ReverseProxy) Handler(src ReverseProxySource, w http.ResponseWriter, r logger := mw.LogEntryFrom(r) isAuthenticated := false - accessToken, err := src.GetAccessToken(r) + _, accessToken, err := getSessionWithValidToken(src, r) switch { case err == nil: - // add authentication if session cookie and token checks out + // add authentication if session checks out isAuthenticated = true case errors.Is(err, context.Canceled): logger.Debugf("default: unauthenticated: %+v (client disconnected before we could respond)", err) @@ -106,6 +106,20 @@ func (rp *ReverseProxy) Handler(src ReverseProxySource, w http.ResponseWriter, r rp.ServeHTTP(w, r.WithContext(ctx)) } +func getSessionWithValidToken(src ReverseProxySource, r *http.Request) (*session.Session, string, error) { + sess, err := src.GetSession(r) + if err != nil { + return nil, "", err + } + + accessToken, err := sess.AccessToken() + if err != nil { + return nil, "", err + } + + return sess, accessToken, nil +} + type logrusErrorWriter struct{} func (w logrusErrorWriter) Write(p []byte) (n int, err error) {