From a639ff290335d62bfeee5007bdddbc85f96d9f4d Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Mon, 18 Jul 2022 16:51:00 +0200 Subject: [PATCH] refactor(retry): extract retry package, add retry for session operations --- pkg/handler/handler_callback.go | 23 +++------ pkg/handler/handler_frontchannellogout.go | 2 +- pkg/handler/session.go | 59 +++++++++++++++++++---- pkg/retry/retry.go | 50 +++++++++++++++++++ 4 files changed, 106 insertions(+), 28 deletions(-) create mode 100644 pkg/retry/retry.go diff --git a/pkg/handler/handler_callback.go b/pkg/handler/handler_callback.go index b02a389..8247af6 100644 --- a/pkg/handler/handler_callback.go +++ b/pkg/handler/handler_callback.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "time" "github.com/sethvargo/go-retry" log "github.com/sirupsen/logrus" @@ -14,14 +13,11 @@ import ( logentry "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" "github.com/nais/wonderwall/pkg/openid/client" -) - -const ( - retryBaseDuration = 50 * time.Millisecond - retryMaxDuration = 1 * time.Second + retrypkg "github.com/nais/wonderwall/pkg/retry" ) func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { + // unconditionally clear login cookie h.clearLoginCookies(w) @@ -92,8 +88,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC return nil } - err = retry.Do(r.Context(), backoff(), retryable) - if err != nil { + if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { return nil, err } @@ -103,7 +98,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) { var tokenResponse *loginstatus.TokenResponse - err := retry.Do(r.Context(), backoff(), func(ctx context.Context) error { + retryable := func(ctx context.Context) error { var err error tokenResponse, err = h.Loginstatus.ExchangeToken(ctx, tokens.AccessToken) @@ -113,8 +108,8 @@ func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (* } return nil - }) - if err != nil { + } + if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { return nil, err } @@ -129,9 +124,3 @@ func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string) logentry.LogEntry(r).WithFields(fields).Info("callback: successful login") } - -func backoff() retry.Backoff { - b := retry.NewFibonacci(retryBaseDuration) - b = retry.WithMaxDuration(retryMaxDuration, b) - return b -} diff --git a/pkg/handler/handler_frontchannellogout.go b/pkg/handler/handler_frontchannellogout.go index dc17959..e532a82 100644 --- a/pkg/handler/handler_frontchannellogout.go +++ b/pkg/handler/handler_frontchannellogout.go @@ -26,7 +26,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) { sid := logoutFrontchannel.Sid() sessionID := h.localSessionID(sid) - sessionData, err := h.getSession(r.Context(), sessionID) + sessionData, err := h.getSession(r, sessionID) if err != nil { logentry.LogEntry(r).Infof("front-channel logout: getting session (user might already be logged out): %+v", err) } diff --git a/pkg/handler/session.go b/pkg/handler/session.go index 76f74d5..cb6c6b4 100644 --- a/pkg/handler/session.go +++ b/pkg/handler/session.go @@ -8,10 +8,12 @@ import ( "time" "github.com/go-redis/redis/v8" + "github.com/sethvargo/go-retry" "github.com/nais/wonderwall/pkg/cookie" logentry "github.com/nais/wonderwall/pkg/middleware" "github.com/nais/wonderwall/pkg/openid" + retrypkg "github.com/nais/wonderwall/pkg/retry" "github.com/nais/wonderwall/pkg/session" ) @@ -30,7 +32,7 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( return nil, fmt.Errorf("no session cookie: %w", err) } - sessionData, err := h.getSession(r.Context(), sessionID) + sessionData, err := h.getSession(r, sessionID) if err == nil { h.DeleteSessionFallback(w, r) return sessionData, nil @@ -50,10 +52,26 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( return fallbackSessionData, nil } -func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Data, error) { - encryptedSessionData, err := h.Sessions.Read(ctx, sessionID) - if err != nil { - return nil, fmt.Errorf("reading session data from store: %w", err) +func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, error) { + var encryptedSessionData *session.EncryptedData + var err error + + retryable := func(ctx context.Context) error { + encryptedSessionData, err = h.Sessions.Read(ctx, sessionID) + if err == nil { + return nil + } + + err = fmt.Errorf("reading session data from store: %w", err) + if errors.Is(err, redis.Nil) { + return err + } + + return retry.RetryableError(err) + } + + if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { + return nil, err } sessionData, err := encryptedSessionData.Decrypt(h.Crypter) @@ -101,8 +119,16 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * return fmt.Errorf("encrypting session data: %w", err) } - err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime) - if err == nil { + retryable := func(ctx context.Context) error { + err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime) + if err != nil { + return retry.RetryableError(err) + } + + return nil + } + + if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err == nil { h.DeleteSessionFallback(w, r) return nil } @@ -118,9 +144,22 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * } func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, sessionID string) error { - err := h.Sessions.Delete(r.Context(), sessionID) - if err != nil { - return fmt.Errorf("deleting session from store: %w", err) + retryable := func(ctx context.Context) error { + err := h.Sessions.Delete(r.Context(), sessionID) + if err == nil { + return nil + } + + err = fmt.Errorf("deleting session from store: %w", err) + if errors.Is(err, redis.Nil) { + return err + } + + return retry.RetryableError(err) + } + + if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { + return err } h.DeleteSessionFallback(w, r) diff --git a/pkg/retry/retry.go b/pkg/retry/retry.go new file mode 100644 index 0000000..d6e1c73 --- /dev/null +++ b/pkg/retry/retry.go @@ -0,0 +1,50 @@ +package retry + +import ( + "time" + + "github.com/sethvargo/go-retry" +) + +const ( + DefaultBaseDuration = 50 * time.Millisecond + DefaultMaxDuration = 1 * time.Second +) + +var DefaultBackoff = Fibonacci().Backoff() + +type FibonacciBackoff struct { + base time.Duration + max time.Duration + backoff retry.Backoff +} + +func (in FibonacciBackoff) WithBase(base time.Duration) FibonacciBackoff { + in.base = base + in.backoff = fibonacci(in.base, in.max) + return in +} + +func (in FibonacciBackoff) WithMax(max time.Duration) FibonacciBackoff { + in.max = max + in.backoff = fibonacci(in.base, in.max) + return in +} + +func (in FibonacciBackoff) Backoff() retry.Backoff { + return in.backoff +} + +func Fibonacci() FibonacciBackoff { + return FibonacciBackoff{ + base: DefaultBaseDuration, + max: DefaultMaxDuration, + backoff: fibonacci(DefaultBaseDuration, DefaultMaxDuration), + } +} + +func fibonacci(base, max time.Duration) retry.Backoff { + b := retry.NewFibonacci(base) + b = retry.WithMaxDuration(max, b) + return b +}