From 5b33313ccb3ca11fc5e9eba804eda52f49a5e552 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Tue, 21 Feb 2023 14:12:56 +0100 Subject: [PATCH] feat(session): add GetOrRefresh method --- pkg/session/session.go | 4 ++++ pkg/session/session_manager.go | 22 ++++++++++++++++++++++ pkg/session/session_reader.go | 2 +- pkg/session/ticket.go | 4 ++-- 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/pkg/session/session.go b/pkg/session/session.go index 71c369a..0492aa2 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -42,6 +42,10 @@ type Writer interface { type Manager interface { Reader Writer + + // GetOrRefresh returns the session for a given http.Request. If the tokens within the session are expired and the + // session is still valid, it will automatically attempt to refresh and update the session. + GetOrRefresh(r *http.Request) (*Session, error) } type Session struct { diff --git a/pkg/session/session_manager.go b/pkg/session/session_manager.go index 25f91c1..55fc6c1 100644 --- a/pkg/session/session_manager.go +++ b/pkg/session/session_manager.go @@ -102,6 +102,28 @@ func (in *manager) DeleteForExternalID(ctx context.Context, id string) error { return in.deleteForKey(ctx, key) } +func (in *manager) GetOrRefresh(r *http.Request) (*Session, error) { + sess, err := in.Get(r) + if err != nil { + return nil, fmt.Errorf("getting session: %w", err) + } + + if !sess.ShouldRefresh() { + return sess, nil + } + + refreshed, err := in.Refresh(r, sess) + if errors.Is(err, ErrInvalidExternal) || errors.Is(err, ErrInvalid) { + return nil, err + } else if err != nil { + mw.LogEntryFrom(r).Warnf("session: could not refresh tokens; falling back to existing tokens: %+v", err) + } else { + sess = refreshed + } + + return sess, nil +} + func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) { if !in.cfg.Session.Refresh || !sess.CanRefresh() { return sess, nil diff --git a/pkg/session/session_reader.go b/pkg/session/session_reader.go index cf554fd..d7bad2f 100644 --- a/pkg/session/session_reader.go +++ b/pkg/session/session_reader.go @@ -37,7 +37,7 @@ func NewReader(cfg *config.Config, cookieCrypter crypto.Crypter) (Reader, error) func (in *reader) Get(r *http.Request) (*Session, error) { ticket, err := getTicket(r, in.cookieCrypter) if err != nil { - return nil, fmt.Errorf("get: %w", err) + return nil, err } return in.GetForTicket(r.Context(), ticket) diff --git a/pkg/session/ticket.go b/pkg/session/ticket.go index c3039e2..614b18f 100644 --- a/pkg/session/ticket.go +++ b/pkg/session/ticket.go @@ -60,10 +60,10 @@ func (c *Ticket) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter c func getTicket(r *http.Request, crypter crypto.Crypter) (*Ticket, error) { ticketJson, err := cookie.GetDecrypted(r, cookie.Session, crypter) if errors.Is(err, http.ErrNoCookie) { - return nil, fmt.Errorf("ticket: session cookie: %w", ErrNotFound) + return nil, fmt.Errorf("ticket: cookie %w", ErrNotFound) } if errors.Is(err, cookie.ErrInvalidValue) || errors.Is(err, cookie.ErrDecrypt) { - return nil, fmt.Errorf("ticket: session cookie: %w: %w", ErrInvalid, err) + return nil, fmt.Errorf("ticket: cookie: %w: %w", ErrInvalid, err) } if err != nil { return nil, err