mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-11 10:56:53 +00:00
refactor(retry): extract retry package, add retry for session operations
This commit is contained in:
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/sethvargo/go-retry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -14,14 +13,11 @@ import (
|
||||
logentry "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
"github.com/nais/wonderwall/pkg/openid/client"
|
||||
)
|
||||
|
||||
const (
|
||||
retryBaseDuration = 50 * time.Millisecond
|
||||
retryMaxDuration = 1 * time.Second
|
||||
retrypkg "github.com/nais/wonderwall/pkg/retry"
|
||||
)
|
||||
|
||||
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// unconditionally clear login cookie
|
||||
h.clearLoginCookies(w)
|
||||
|
||||
@@ -92,8 +88,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC
|
||||
return nil
|
||||
}
|
||||
|
||||
err = retry.Do(r.Context(), backoff(), retryable)
|
||||
if err != nil {
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -103,7 +98,7 @@ func (h *Handler) redeemValidTokens(r *http.Request, loginCallback client.LoginC
|
||||
func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*loginstatus.TokenResponse, error) {
|
||||
var tokenResponse *loginstatus.TokenResponse
|
||||
|
||||
err := retry.Do(r.Context(), backoff(), func(ctx context.Context) error {
|
||||
retryable := func(ctx context.Context) error {
|
||||
var err error
|
||||
|
||||
tokenResponse, err = h.Loginstatus.ExchangeToken(ctx, tokens.AccessToken)
|
||||
@@ -113,8 +108,8 @@ func (h *Handler) getLoginstatusToken(r *http.Request, tokens *openid.Tokens) (*
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
}
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -129,9 +124,3 @@ func logSuccessfulLogin(r *http.Request, tokens *openid.Tokens, referer string)
|
||||
|
||||
logentry.LogEntry(r).WithFields(fields).Info("callback: successful login")
|
||||
}
|
||||
|
||||
func backoff() retry.Backoff {
|
||||
b := retry.NewFibonacci(retryBaseDuration)
|
||||
b = retry.WithMaxDuration(retryMaxDuration, b)
|
||||
return b
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
sid := logoutFrontchannel.Sid()
|
||||
sessionID := h.localSessionID(sid)
|
||||
sessionData, err := h.getSession(r.Context(), sessionID)
|
||||
sessionData, err := h.getSession(r, sessionID)
|
||||
if err != nil {
|
||||
logentry.LogEntry(r).Infof("front-channel logout: getting session (user might already be logged out): %+v", err)
|
||||
}
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/sethvargo/go-retry"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/cookie"
|
||||
logentry "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/openid"
|
||||
retrypkg "github.com/nais/wonderwall/pkg/retry"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
@@ -30,7 +32,7 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
|
||||
return nil, fmt.Errorf("no session cookie: %w", err)
|
||||
}
|
||||
|
||||
sessionData, err := h.getSession(r.Context(), sessionID)
|
||||
sessionData, err := h.getSession(r, sessionID)
|
||||
if err == nil {
|
||||
h.DeleteSessionFallback(w, r)
|
||||
return sessionData, nil
|
||||
@@ -50,10 +52,26 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
|
||||
return fallbackSessionData, nil
|
||||
}
|
||||
|
||||
func (h *Handler) getSession(ctx context.Context, sessionID string) (*session.Data, error) {
|
||||
encryptedSessionData, err := h.Sessions.Read(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading session data from store: %w", err)
|
||||
func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, error) {
|
||||
var encryptedSessionData *session.EncryptedData
|
||||
var err error
|
||||
|
||||
retryable := func(ctx context.Context) error {
|
||||
encryptedSessionData, err = h.Sessions.Read(ctx, sessionID)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = fmt.Errorf("reading session data from store: %w", err)
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return err
|
||||
}
|
||||
|
||||
return retry.RetryableError(err)
|
||||
}
|
||||
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
|
||||
@@ -101,8 +119,16 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
|
||||
return fmt.Errorf("encrypting session data: %w", err)
|
||||
}
|
||||
|
||||
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
|
||||
if err == nil {
|
||||
retryable := func(ctx context.Context) error {
|
||||
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
|
||||
if err != nil {
|
||||
return retry.RetryableError(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err == nil {
|
||||
h.DeleteSessionFallback(w, r)
|
||||
return nil
|
||||
}
|
||||
@@ -118,9 +144,22 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
|
||||
}
|
||||
|
||||
func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, sessionID string) error {
|
||||
err := h.Sessions.Delete(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("deleting session from store: %w", err)
|
||||
retryable := func(ctx context.Context) error {
|
||||
err := h.Sessions.Delete(r.Context(), sessionID)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = fmt.Errorf("deleting session from store: %w", err)
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return err
|
||||
}
|
||||
|
||||
return retry.RetryableError(err)
|
||||
}
|
||||
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.DeleteSessionFallback(w, r)
|
||||
|
||||
50
pkg/retry/retry.go
Normal file
50
pkg/retry/retry.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package retry
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sethvargo/go-retry"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultBaseDuration = 50 * time.Millisecond
|
||||
DefaultMaxDuration = 1 * time.Second
|
||||
)
|
||||
|
||||
var DefaultBackoff = Fibonacci().Backoff()
|
||||
|
||||
type FibonacciBackoff struct {
|
||||
base time.Duration
|
||||
max time.Duration
|
||||
backoff retry.Backoff
|
||||
}
|
||||
|
||||
func (in FibonacciBackoff) WithBase(base time.Duration) FibonacciBackoff {
|
||||
in.base = base
|
||||
in.backoff = fibonacci(in.base, in.max)
|
||||
return in
|
||||
}
|
||||
|
||||
func (in FibonacciBackoff) WithMax(max time.Duration) FibonacciBackoff {
|
||||
in.max = max
|
||||
in.backoff = fibonacci(in.base, in.max)
|
||||
return in
|
||||
}
|
||||
|
||||
func (in FibonacciBackoff) Backoff() retry.Backoff {
|
||||
return in.backoff
|
||||
}
|
||||
|
||||
func Fibonacci() FibonacciBackoff {
|
||||
return FibonacciBackoff{
|
||||
base: DefaultBaseDuration,
|
||||
max: DefaultMaxDuration,
|
||||
backoff: fibonacci(DefaultBaseDuration, DefaultMaxDuration),
|
||||
}
|
||||
}
|
||||
|
||||
func fibonacci(base, max time.Duration) retry.Backoff {
|
||||
b := retry.NewFibonacci(base)
|
||||
b = retry.WithMaxDuration(max, b)
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user