fix(session/handler): ensure access token is not expired before proxying

This commit is contained in:
Trong Huu Nguyen
2022-08-26 17:58:39 +02:00
parent d5bbca9897
commit 5ec969981d
3 changed files with 40 additions and 25 deletions

View File

@@ -4,8 +4,6 @@ import (
"errors"
"net/http"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/handler/url"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
@@ -16,8 +14,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
logger := mw.LogEntryFrom(r).WithField("request_path", r.URL.Path)
isAuthenticated := false
accessToken, ok := h.accessToken(r, logger)
if ok {
accessToken, err := h.Sessions.GetAccessToken(r)
if err == nil {
// add authentication if session cookie and token checks out
isAuthenticated = true
@@ -26,6 +24,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
isAuthenticated = false
logger.Info("default: loginstatus was enabled, but no matching cookie was found; state is now unauthenticated")
}
} else if errors.Is(err, session.UnexpectedError) {
logger.Errorf("default: getting session: %+v", err)
}
if h.AutoLogin.NeedsLogin(r, isAuthenticated) {
@@ -47,16 +47,3 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
h.ReverseProxy.ServeHTTP(w, r.WithContext(ctx))
}
func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) {
sessionData, err := h.Sessions.GetOrRefresh(r)
if err == nil && sessionData != nil && sessionData.HasAccessToken() {
return sessionData.AccessToken, true
}
if errors.Is(err, session.UnexpectedError) {
logger.Errorf("default: getting session: %+v", err)
}
return "", false
}

View File

@@ -120,6 +120,14 @@ func NewMetadata(expiresIn time.Duration, endsIn time.Duration) *Metadata {
}
}
func (in *Metadata) IsExpired() bool {
return time.Now().After(in.TokensExpireAt)
}
func (in *Metadata) IsRefreshOnCooldown() bool {
return time.Now().Before(in.RefreshCooldown())
}
func (in *Metadata) NextRefresh() time.Time {
// subtract the leeway to ensure that we refresh before expiry
next := in.TokensExpireAt.Add(-RefreshLeeway)
@@ -150,12 +158,8 @@ func (in *Metadata) RefreshCooldown() time.Time {
return refreshed.Add(RefreshMinInterval)
}
func (in *Metadata) RefreshOnCooldown() bool {
return time.Now().Before(in.RefreshCooldown())
}
func (in *Metadata) ShouldRefresh() bool {
if in.RefreshOnCooldown() {
if in.IsRefreshOnCooldown() {
return false
}
@@ -178,7 +182,7 @@ func (in *Metadata) Verbose() MetadataVerbose {
SessionEndsInSeconds: toSeconds(endTime.Sub(now)),
TokensExpireInSeconds: toSeconds(expireTime.Sub(now)),
TokensNextRefreshInSeconds: toSeconds(nextRefreshTime.Sub(now)),
TokensRefreshCooldown: in.RefreshOnCooldown(),
TokensRefreshCooldown: in.IsRefreshOnCooldown(),
TokensRefreshCooldownSeconds: toSeconds(in.RefreshCooldown().Sub(now)),
}
}

View File

@@ -22,7 +22,10 @@ import (
)
var (
CookieNotFoundError = errors.New("cookie not found")
CookieNotFoundError = errors.New("cookie not found")
NoSessionDataError = errors.New("no session data")
NoAccessTokenError = errors.New("no access token in session data")
ExpiredAccessTokenError = errors.New("access token is expired")
)
type Handler struct {
@@ -112,6 +115,27 @@ func (h *Handler) Get(r *http.Request) (*Data, error) {
return h.GetForKey(r, key)
}
func (h *Handler) GetAccessToken(r *http.Request) (string, error) {
sessionData, err := h.GetOrRefresh(r)
if err != nil {
return "", err
}
if sessionData == nil {
return "", NoSessionDataError
}
if !sessionData.HasAccessToken() {
return "", NoAccessTokenError
}
if sessionData.Metadata.IsExpired() {
return "", ExpiredAccessTokenError
}
return sessionData.AccessToken, nil
}
// GetForID returns the session data for a given session ID.
func (h *Handler) GetForID(r *http.Request, id string) (*Data, error) {
key := h.Key(id)
@@ -206,7 +230,7 @@ func (h *Handler) Key(sessionID string) string {
// Refresh refreshes the user's session and returns the updated session data.
func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error) {
if !h.refreshEnabled || !data.HasRefreshToken() || data.Metadata.RefreshOnCooldown() {
if !h.refreshEnabled || !data.HasRefreshToken() || data.Metadata.IsRefreshOnCooldown() {
return data, nil
}