mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 15:22:58 +00:00
fix(session/handler): ensure access token is not expired before proxying
This commit is contained in:
@@ -4,8 +4,6 @@ import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/handler/url"
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
@@ -16,8 +14,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
logger := mw.LogEntryFrom(r).WithField("request_path", r.URL.Path)
|
||||
isAuthenticated := false
|
||||
|
||||
accessToken, ok := h.accessToken(r, logger)
|
||||
if ok {
|
||||
accessToken, err := h.Sessions.GetAccessToken(r)
|
||||
if err == nil {
|
||||
// add authentication if session cookie and token checks out
|
||||
isAuthenticated = true
|
||||
|
||||
@@ -26,6 +24,8 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
isAuthenticated = false
|
||||
logger.Info("default: loginstatus was enabled, but no matching cookie was found; state is now unauthenticated")
|
||||
}
|
||||
} else if errors.Is(err, session.UnexpectedError) {
|
||||
logger.Errorf("default: getting session: %+v", err)
|
||||
}
|
||||
|
||||
if h.AutoLogin.NeedsLogin(r, isAuthenticated) {
|
||||
@@ -47,16 +47,3 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
h.ReverseProxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) {
|
||||
sessionData, err := h.Sessions.GetOrRefresh(r)
|
||||
if err == nil && sessionData != nil && sessionData.HasAccessToken() {
|
||||
return sessionData.AccessToken, true
|
||||
}
|
||||
|
||||
if errors.Is(err, session.UnexpectedError) {
|
||||
logger.Errorf("default: getting session: %+v", err)
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -120,6 +120,14 @@ func NewMetadata(expiresIn time.Duration, endsIn time.Duration) *Metadata {
|
||||
}
|
||||
}
|
||||
|
||||
func (in *Metadata) IsExpired() bool {
|
||||
return time.Now().After(in.TokensExpireAt)
|
||||
}
|
||||
|
||||
func (in *Metadata) IsRefreshOnCooldown() bool {
|
||||
return time.Now().Before(in.RefreshCooldown())
|
||||
}
|
||||
|
||||
func (in *Metadata) NextRefresh() time.Time {
|
||||
// subtract the leeway to ensure that we refresh before expiry
|
||||
next := in.TokensExpireAt.Add(-RefreshLeeway)
|
||||
@@ -150,12 +158,8 @@ func (in *Metadata) RefreshCooldown() time.Time {
|
||||
return refreshed.Add(RefreshMinInterval)
|
||||
}
|
||||
|
||||
func (in *Metadata) RefreshOnCooldown() bool {
|
||||
return time.Now().Before(in.RefreshCooldown())
|
||||
}
|
||||
|
||||
func (in *Metadata) ShouldRefresh() bool {
|
||||
if in.RefreshOnCooldown() {
|
||||
if in.IsRefreshOnCooldown() {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -178,7 +182,7 @@ func (in *Metadata) Verbose() MetadataVerbose {
|
||||
SessionEndsInSeconds: toSeconds(endTime.Sub(now)),
|
||||
TokensExpireInSeconds: toSeconds(expireTime.Sub(now)),
|
||||
TokensNextRefreshInSeconds: toSeconds(nextRefreshTime.Sub(now)),
|
||||
TokensRefreshCooldown: in.RefreshOnCooldown(),
|
||||
TokensRefreshCooldown: in.IsRefreshOnCooldown(),
|
||||
TokensRefreshCooldownSeconds: toSeconds(in.RefreshCooldown().Sub(now)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,10 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
CookieNotFoundError = errors.New("cookie not found")
|
||||
CookieNotFoundError = errors.New("cookie not found")
|
||||
NoSessionDataError = errors.New("no session data")
|
||||
NoAccessTokenError = errors.New("no access token in session data")
|
||||
ExpiredAccessTokenError = errors.New("access token is expired")
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
@@ -112,6 +115,27 @@ func (h *Handler) Get(r *http.Request) (*Data, error) {
|
||||
return h.GetForKey(r, key)
|
||||
}
|
||||
|
||||
func (h *Handler) GetAccessToken(r *http.Request) (string, error) {
|
||||
sessionData, err := h.GetOrRefresh(r)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if sessionData == nil {
|
||||
return "", NoSessionDataError
|
||||
}
|
||||
|
||||
if !sessionData.HasAccessToken() {
|
||||
return "", NoAccessTokenError
|
||||
}
|
||||
|
||||
if sessionData.Metadata.IsExpired() {
|
||||
return "", ExpiredAccessTokenError
|
||||
}
|
||||
|
||||
return sessionData.AccessToken, nil
|
||||
}
|
||||
|
||||
// GetForID returns the session data for a given session ID.
|
||||
func (h *Handler) GetForID(r *http.Request, id string) (*Data, error) {
|
||||
key := h.Key(id)
|
||||
@@ -206,7 +230,7 @@ func (h *Handler) Key(sessionID string) string {
|
||||
|
||||
// Refresh refreshes the user's session and returns the updated session data.
|
||||
func (h *Handler) Refresh(r *http.Request, key string, data *Data) (*Data, error) {
|
||||
if !h.refreshEnabled || !data.HasRefreshToken() || data.Metadata.RefreshOnCooldown() {
|
||||
if !h.refreshEnabled || !data.HasRefreshToken() || data.Metadata.IsRefreshOnCooldown() {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user