From 5ec969981de966f6deb39a742e8cca2d04ae33b6 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Fri, 26 Aug 2022 17:58:39 +0200 Subject: [PATCH] fix(session/handler): ensure access token is not expired before proxying --- pkg/handler/handler_default.go | 21 ++++----------------- pkg/session/data.go | 16 ++++++++++------ pkg/session/handler.go | 28 ++++++++++++++++++++++++++-- 3 files changed, 40 insertions(+), 25 deletions(-) diff --git a/pkg/handler/handler_default.go b/pkg/handler/handler_default.go index bd899dc..78b82a5 100644 --- a/pkg/handler/handler_default.go +++ b/pkg/handler/handler_default.go @@ -4,8 +4,6 @@ import ( "errors" "net/http" - log "github.com/sirupsen/logrus" - "github.com/nais/wonderwall/pkg/handler/url" mw "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/session" @@ -16,8 +14,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { logger := mw.LogEntryFrom(r).WithField("request_path", r.URL.Path) isAuthenticated := false - accessToken, ok := h.accessToken(r, logger) - if ok { + accessToken, err := h.Sessions.GetAccessToken(r) + if err == nil { // add authentication if session cookie and token checks out isAuthenticated = true @@ -26,6 +24,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { isAuthenticated = false logger.Info("default: loginstatus was enabled, but no matching cookie was found; state is now unauthenticated") } + } else if errors.Is(err, session.UnexpectedError) { + logger.Errorf("default: getting session: %+v", err) } if h.AutoLogin.NeedsLogin(r, isAuthenticated) { @@ -47,16 +47,3 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { h.ReverseProxy.ServeHTTP(w, r.WithContext(ctx)) } - -func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) { - sessionData, err := h.Sessions.GetOrRefresh(r) - if err == nil && sessionData != nil && sessionData.HasAccessToken() { - return sessionData.AccessToken, true - } - - if errors.Is(err, session.UnexpectedError) { - logger.Errorf("default: getting session: %+v", err) - } - - return "", false -} diff --git a/pkg/session/data.go b/pkg/session/data.go index 4d32a71..f117238 100644 --- a/pkg/session/data.go +++ b/pkg/session/data.go @@ -120,6 +120,14 @@ func NewMetadata(expiresIn time.Duration, endsIn time.Duration) *Metadata { } } +func (in *Metadata) IsExpired() bool { + return time.Now().After(in.TokensExpireAt) +} + +func (in *Metadata) IsRefreshOnCooldown() bool { + return time.Now().Before(in.RefreshCooldown()) +} + func (in *Metadata) NextRefresh() time.Time { // subtract the leeway to ensure that we refresh before expiry next := in.TokensExpireAt.Add(-RefreshLeeway) @@ -150,12 +158,8 @@ func (in *Metadata) RefreshCooldown() time.Time { return refreshed.Add(RefreshMinInterval) } -func (in *Metadata) RefreshOnCooldown() bool { - return time.Now().Before(in.RefreshCooldown()) -} - func (in *Metadata) ShouldRefresh() bool { - if in.RefreshOnCooldown() { + if in.IsRefreshOnCooldown() { return false } @@ -178,7 +182,7 @@ func (in *Metadata) Verbose() MetadataVerbose { SessionEndsInSeconds: toSeconds(endTime.Sub(now)), TokensExpireInSeconds: toSeconds(expireTime.Sub(now)), TokensNextRefreshInSeconds: toSeconds(nextRefreshTime.Sub(now)), - TokensRefreshCooldown: in.RefreshOnCooldown(), + TokensRefreshCooldown: in.IsRefreshOnCooldown(), TokensRefreshCooldownSeconds: toSeconds(in.RefreshCooldown().Sub(now)), } } diff --git a/pkg/session/handler.go b/pkg/session/handler.go index dc08ecb..8448609 100644 --- a/pkg/session/handler.go +++ b/pkg/session/handler.go @@ -22,7 +22,10 @@ import ( ) var ( - CookieNotFoundError = errors.New("cookie not found") + CookieNotFoundError = errors.New("cookie not found") + NoSessionDataError = errors.New("no session data") + NoAccessTokenError = errors.New("no access token in session data") + ExpiredAccessTokenError = errors.New("access token is expired") ) type Handler struct { @@ -112,6 +115,27 @@ func (h *Handler) Get(r *http.Request) (*Data, error) { return h.GetForKey(r, key) } +func (h *Handler) GetAccessToken(r *http.Request) (string, error) { + sessionData, err := h.GetOrRefresh(r) + if err != nil { + return "", err + } + + if sessionData == nil { + return "", NoSessionDataError + } + + if !sessionData.HasAccessToken() { + return "", NoAccessTokenError + } + + if sessionData.Metadata.IsExpired() { + return "", ExpiredAccessTokenError + } + + return sessionData.AccessToken, nil +} + // GetForID returns the session data for a given session ID. func (h *Handler) GetForID(r *http.Request, id string) (*Data, error) { key := h.Key(id) @@ -206,7 +230,7 @@ func (h *Handler) Key(sessionID string) string { // Refresh refreshes the user's session and returns the updated session data. func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error) { - if !h.refreshEnabled || !data.HasRefreshToken() || data.Metadata.RefreshOnCooldown() { + if !h.refreshEnabled || !data.HasRefreshToken() || data.Metadata.IsRefreshOnCooldown() { return data, nil }