refactor(retry): extract retry package, add retry for session operations

This commit is contained in:
Trong Huu Nguyen
2022-07-18 16:51:00 +02:00
parent 81fa96ccb8
commit a639ff2903
4 changed files with 106 additions and 28 deletions

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"net/http"
"time"
"github.com/sethvargo/go-retry"
log "github.com/sirupsen/logrus"
@@ -14,14 +13,11 @@ import (
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
"github.com/nais/wonderwall/pkg/openid/client"
)
const (
retryBaseDuration = 50 * time.Millisecond
retryMaxDuration = 1 * time.Second
retrypkg "github.com/nais/wonderwall/pkg/retry"
)
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
// unconditionally clear login cookie
h.clearLoginCookies(w)
@@ -92,8 +88,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC
return nil
}
err = retry.Do(r.Context(), backoff(), retryable)
if err != nil {
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return nil, err
}
@@ -103,7 +98,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC
func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) {
var tokenResponse *loginstatus.TokenResponse
err := retry.Do(r.Context(), backoff(), func(ctx context.Context) error {
retryable := func(ctx context.Context) error {
var err error
tokenResponse, err = h.Loginstatus.ExchangeToken(ctx, tokens.AccessToken)
@@ -113,8 +108,8 @@ func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*
}
return nil
})
if err != nil {
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return nil, err
}
@@ -129,9 +124,3 @@ func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string)
logentry.LogEntry(r).WithFields(fields).Info("callback: successful login")
}
func backoff() retry.Backoff {
b := retry.NewFibonacci(retryBaseDuration)
b = retry.WithMaxDuration(retryMaxDuration, b)
return b
}

View File

@@ -26,7 +26,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
sid := logoutFrontchannel.Sid()
sessionID := h.localSessionID(sid)
sessionData, err := h.getSession(r.Context(), sessionID)
sessionData, err := h.getSession(r, sessionID)
if err != nil {
logentry.LogEntry(r).Infof("front-channel logout: getting session (user might already be logged out): %+v", err)
}

View File

@@ -8,10 +8,12 @@ import (
"time"
"github.com/go-redis/redis/v8"
"github.com/sethvargo/go-retry"
"github.com/nais/wonderwall/pkg/cookie"
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/openid"
retrypkg "github.com/nais/wonderwall/pkg/retry"
"github.com/nais/wonderwall/pkg/session"
)
@@ -30,7 +32,7 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
return nil, fmt.Errorf("no session cookie: %w", err)
}
sessionData, err := h.getSession(r.Context(), sessionID)
sessionData, err := h.getSession(r, sessionID)
if err == nil {
h.DeleteSessionFallback(w, r)
return sessionData, nil
@@ -50,10 +52,26 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
return fallbackSessionData, nil
}
func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Data, error) {
encryptedSessionData, err := h.Sessions.Read(ctx, sessionID)
if err != nil {
return nil, fmt.Errorf("reading session data from store: %w", err)
func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, error) {
var encryptedSessionData *session.EncryptedData
var err error
retryable := func(ctx context.Context) error {
encryptedSessionData, err = h.Sessions.Read(ctx, sessionID)
if err == nil {
return nil
}
err = fmt.Errorf("reading session data from store: %w", err)
if errors.Is(err, redis.Nil) {
return err
}
return retry.RetryableError(err)
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return nil, err
}
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
@@ -101,8 +119,16 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
return fmt.Errorf("encrypting session data: %w", err)
}
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
if err == nil {
retryable := func(ctx context.Context) error {
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
if err != nil {
return retry.RetryableError(err)
}
return nil
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err == nil {
h.DeleteSessionFallback(w, r)
return nil
}
@@ -118,9 +144,22 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
}
func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, sessionID string) error {
err := h.Sessions.Delete(r.Context(), sessionID)
if err != nil {
return fmt.Errorf("deleting session from store: %w", err)
retryable := func(ctx context.Context) error {
err := h.Sessions.Delete(r.Context(), sessionID)
if err == nil {
return nil
}
err = fmt.Errorf("deleting session from store: %w", err)
if errors.Is(err, redis.Nil) {
return err
}
return retry.RetryableError(err)
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return err
}
h.DeleteSessionFallback(w, r)

50
pkg/retry/retry.go Normal file
View File

@@ -0,0 +1,50 @@
package retry
import (
"time"
"github.com/sethvargo/go-retry"
)
const (
DefaultBaseDuration = 50 * time.Millisecond
DefaultMaxDuration = 1 * time.Second
)
var DefaultBackoff = Fibonacci().Backoff()
type FibonacciBackoff struct {
base time.Duration
max time.Duration
backoff retry.Backoff
}
func (in FibonacciBackoff) WithBase(base time.Duration) FibonacciBackoff {
in.base = base
in.backoff = fibonacci(in.base, in.max)
return in
}
func (in FibonacciBackoff) WithMax(max time.Duration) FibonacciBackoff {
in.max = max
in.backoff = fibonacci(in.base, in.max)
return in
}
func (in FibonacciBackoff) Backoff() retry.Backoff {
return in.backoff
}
func Fibonacci() FibonacciBackoff {
return FibonacciBackoff{
base: DefaultBaseDuration,
max: DefaultMaxDuration,
backoff: fibonacci(DefaultBaseDuration, DefaultMaxDuration),
}
}
func fibonacci(base, max time.Duration) retry.Backoff {
b := retry.NewFibonacci(base)
b = retry.WithMaxDuration(max, b)
return b
}