feat(session): add GetOrRefresh method

This commit is contained in:
Trong Huu Nguyen
2023-02-21 14:12:56 +01:00
parent 7a52b0d1a3
commit 5b33313ccb
4 changed files with 29 additions and 3 deletions

View File

@@ -42,6 +42,10 @@ type Writer interface {
type Manager interface {
Reader
Writer
// GetOrRefresh returns the session for a given http.Request. If the tokens within the session are expired and the
// session is still valid, it will automatically attempt to refresh and update the session.
GetOrRefresh(r *http.Request) (*Session, error)
}
type Session struct {

View File

@@ -102,6 +102,28 @@ func (in *manager) DeleteForExternalID(ctx context.Context, id string) error {
return in.deleteForKey(ctx, key)
}
func (in *manager) GetOrRefresh(r *http.Request) (*Session, error) {
sess, err := in.Get(r)
if err != nil {
return nil, fmt.Errorf("getting session: %w", err)
}
if !sess.ShouldRefresh() {
return sess, nil
}
refreshed, err := in.Refresh(r, sess)
if errors.Is(err, ErrInvalidExternal) || errors.Is(err, ErrInvalid) {
return nil, err
} else if err != nil {
mw.LogEntryFrom(r).Warnf("session: could not refresh tokens; falling back to existing tokens: %+v", err)
} else {
sess = refreshed
}
return sess, nil
}
func (in *manager) Refresh(r *http.Request, sess *Session) (*Session, error) {
if !in.cfg.Session.Refresh || !sess.CanRefresh() {
return sess, nil

View File

@@ -37,7 +37,7 @@ func NewReader(cfg *config.Config, cookieCrypter crypto.Crypter) (Reader, error)
func (in *reader) Get(r *http.Request) (*Session, error) {
ticket, err := getTicket(r, in.cookieCrypter)
if err != nil {
return nil, fmt.Errorf("get: %w", err)
return nil, err
}
return in.GetForTicket(r.Context(), ticket)

View File

@@ -60,10 +60,10 @@ func (c *Ticket) SetCookie(w http.ResponseWriter, opts cookie.Options, crypter c
func getTicket(r *http.Request, crypter crypto.Crypter) (*Ticket, error) {
ticketJson, err := cookie.GetDecrypted(r, cookie.Session, crypter)
if errors.Is(err, http.ErrNoCookie) {
return nil, fmt.Errorf("ticket: session cookie: %w", ErrNotFound)
return nil, fmt.Errorf("ticket: cookie %w", ErrNotFound)
}
if errors.Is(err, cookie.ErrInvalidValue) || errors.Is(err, cookie.ErrDecrypt) {
return nil, fmt.Errorf("ticket: session cookie: %w: %w", ErrInvalid, err)
return nil, fmt.Errorf("ticket: cookie: %w: %w", ErrInvalid, err)
}
if err != nil {
return nil, err