refactor: clean up session error handling

This commit is contained in:
Trong Huu Nguyen
2022-08-18 21:35:15 +02:00
parent ae8028cc96
commit c15e00469b
6 changed files with 29 additions and 15 deletions

View File

@@ -1,10 +1,14 @@
package handler
import (
"errors"
"net/http"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/handler/url"
mw "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
)
// Default proxies all requests upstream.
@@ -12,7 +16,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
logger := mw.LogEntry(r).WithField("request_path", r.URL.Path)
isAuthenticated := false
accessToken, ok := h.accessToken(r)
accessToken, ok := h.accessToken(r, logger)
if ok {
// add authentication if session cookie and token checks out
isAuthenticated = true
@@ -44,11 +48,15 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
h.ReverseProxy.ServeHTTP(w, r.WithContext(ctx))
}
func (h *Handler) accessToken(r *http.Request) (string, bool) {
func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) {
sessionData, err := h.getSessionFromCookie(r)
if err != nil || sessionData == nil || len(sessionData.AccessToken) == 0 {
return "", false
if err == nil && sessionData != nil && len(sessionData.AccessToken) > 0 {
return sessionData.AccessToken, true
}
return sessionData.AccessToken, true
if errors.Is(err, session.UnexpectedError) {
logger.Errorf("default: getting session: %+v", err)
}
return "", false
}

View File

@@ -36,10 +36,10 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
}
if errors.Is(err, session.KeyNotFoundError) {
return nil, fmt.Errorf("session not found in store: %w", err)
return nil, fmt.Errorf("session not found: %w", err)
}
return nil, fmt.Errorf("get session: store is unavailable: %+v", err)
return nil, err
}
func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, error) {
@@ -52,7 +52,6 @@ func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data,
return nil
}
err = fmt.Errorf("reading session data from store: %w", err)
if errors.Is(err, session.KeyNotFoundError) {
return err
}
@@ -61,7 +60,7 @@ func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data,
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return nil, err
return nil, fmt.Errorf("reading from store: %w", err)
}
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
@@ -122,7 +121,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return fmt.Errorf("create session: store is unavailable: %+v", err)
return fmt.Errorf("writing to store: %w", err)
}
return nil
@@ -135,7 +134,6 @@ func (h *Handler) destroySession(r *http.Request, sessionID string) error {
return nil
}
err = fmt.Errorf("deleting session from store: %w", err)
if errors.Is(err, session.KeyNotFoundError) {
return err
}
@@ -144,7 +142,7 @@ func (h *Handler) destroySession(r *http.Request, sessionID string) error {
}
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
return err
return fmt.Errorf("deleting from store: %w", err)
}
return nil

View File

@@ -52,5 +52,6 @@ func TestMemory(t *testing.T) {
result, err = sess.Read(context.Background(), "key")
assert.Error(t, err)
assert.ErrorIs(t, err, session.KeyNotFoundError)
assert.Nil(t, result)
}

View File

@@ -36,13 +36,18 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedDat
return nil, fmt.Errorf("%w: %s", KeyNotFoundError, err.Error())
}
return nil, err
return nil, fmt.Errorf("%w: %s", UnexpectedError, err.Error())
}
func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error {
return metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error {
err := metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error {
return s.client.Set(ctx, key, value, expiration).Err()
})
if err != nil {
return fmt.Errorf("%w: %s", UnexpectedError, err.Error())
}
return nil
}
func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error {
@@ -57,5 +62,5 @@ func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error {
return fmt.Errorf("%w: %s", KeyNotFoundError, err.Error())
}
return err
return fmt.Errorf("%w: %s", UnexpectedError, err.Error())
}

View File

@@ -65,5 +65,6 @@ func TestRedis(t *testing.T) {
result, err = sess.Read(context.Background(), "key")
assert.Error(t, err)
assert.ErrorIs(t, err, session.KeyNotFoundError)
assert.Nil(t, result)
}

View File

@@ -13,6 +13,7 @@ import (
var (
KeyNotFoundError = errors.New("key not found")
UnexpectedError = errors.New("unexpected error")
)
type Store interface {