From c15e00469bce466c6cba0d84a55df518a16dea88 Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Thu, 18 Aug 2022 21:35:15 +0200 Subject: [PATCH] refactor: clean up session error handling --- pkg/handler/handler_default.go | 18 +++++++++++++----- pkg/handler/session.go | 12 +++++------- pkg/session/memory_test.go | 1 + pkg/session/redis.go | 11 ++++++++--- pkg/session/redis_test.go | 1 + pkg/session/session.go | 1 + 6 files changed, 29 insertions(+), 15 deletions(-) diff --git a/pkg/handler/handler_default.go b/pkg/handler/handler_default.go index b3deb48..8873f7d 100644 --- a/pkg/handler/handler_default.go +++ b/pkg/handler/handler_default.go @@ -1,10 +1,14 @@ package handler import ( + "errors" "net/http" + log "github.com/sirupsen/logrus" + "github.com/nais/wonderwall/pkg/handler/url" mw "github.com/nais/wonderwall/pkg/middleware" + "github.com/nais/wonderwall/pkg/session" ) // Default proxies all requests upstream. @@ -12,7 +16,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { logger := mw.LogEntry(r).WithField("request_path", r.URL.Path) isAuthenticated := false - accessToken, ok := h.accessToken(r) + accessToken, ok := h.accessToken(r, logger) if ok { // add authentication if session cookie and token checks out isAuthenticated = true @@ -44,11 +48,15 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) { h.ReverseProxy.ServeHTTP(w, r.WithContext(ctx)) } -func (h *Handler) accessToken(r *http.Request) (string, bool) { +func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) { sessionData, err := h.getSessionFromCookie(r) - if err != nil || sessionData == nil || len(sessionData.AccessToken) == 0 { - return "", false + if err == nil && sessionData != nil && len(sessionData.AccessToken) > 0 { + return sessionData.AccessToken, true } - return sessionData.AccessToken, true + if errors.Is(err, session.UnexpectedError) { + logger.Errorf("default: getting session: %+v", err) + } + + return "", false } diff --git a/pkg/handler/session.go b/pkg/handler/session.go index 5880ff9..49e4580 100644 --- a/pkg/handler/session.go +++ b/pkg/handler/session.go @@ -36,10 +36,10 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) { } if errors.Is(err, session.KeyNotFoundError) { - return nil, fmt.Errorf("session not found in store: %w", err) + return nil, fmt.Errorf("session not found: %w", err) } - return nil, fmt.Errorf("get session: store is unavailable: %+v", err) + return nil, err } func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, error) { @@ -52,7 +52,6 @@ func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, return nil } - err = fmt.Errorf("reading session data from store: %w", err) if errors.Is(err, session.KeyNotFoundError) { return err } @@ -61,7 +60,7 @@ func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { - return nil, err + return nil, fmt.Errorf("reading from store: %w", err) } sessionData, err := encryptedSessionData.Decrypt(h.Crypter) @@ -122,7 +121,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens * } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { - return fmt.Errorf("create session: store is unavailable: %+v", err) + return fmt.Errorf("writing to store: %w", err) } return nil @@ -135,7 +134,6 @@ func (h *Handler) destroySession(r *http.Request, sessionID string) error { return nil } - err = fmt.Errorf("deleting session from store: %w", err) if errors.Is(err, session.KeyNotFoundError) { return err } @@ -144,7 +142,7 @@ func (h *Handler) destroySession(r *http.Request, sessionID string) error { } if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil { - return err + return fmt.Errorf("deleting from store: %w", err) } return nil diff --git a/pkg/session/memory_test.go b/pkg/session/memory_test.go index cbef10a..44d817d 100644 --- a/pkg/session/memory_test.go +++ b/pkg/session/memory_test.go @@ -52,5 +52,6 @@ func TestMemory(t *testing.T) { result, err = sess.Read(context.Background(), "key") assert.Error(t, err) + assert.ErrorIs(t, err, session.KeyNotFoundError) assert.Nil(t, result) } diff --git a/pkg/session/redis.go b/pkg/session/redis.go index 255ad03..a475695 100644 --- a/pkg/session/redis.go +++ b/pkg/session/redis.go @@ -36,13 +36,18 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedDat return nil, fmt.Errorf("%w: %s", KeyNotFoundError, err.Error()) } - return nil, err + return nil, fmt.Errorf("%w: %s", UnexpectedError, err.Error()) } func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error { - return metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error { + err := metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error { return s.client.Set(ctx, key, value, expiration).Err() }) + if err != nil { + return fmt.Errorf("%w: %s", UnexpectedError, err.Error()) + } + + return nil } func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error { @@ -57,5 +62,5 @@ func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error { return fmt.Errorf("%w: %s", KeyNotFoundError, err.Error()) } - return err + return fmt.Errorf("%w: %s", UnexpectedError, err.Error()) } diff --git a/pkg/session/redis_test.go b/pkg/session/redis_test.go index 4728280..442e9d4 100644 --- a/pkg/session/redis_test.go +++ b/pkg/session/redis_test.go @@ -65,5 +65,6 @@ func TestRedis(t *testing.T) { result, err = sess.Read(context.Background(), "key") assert.Error(t, err) + assert.ErrorIs(t, err, session.KeyNotFoundError) assert.Nil(t, result) } diff --git a/pkg/session/session.go b/pkg/session/session.go index d9ee092..ebdb67e 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -13,6 +13,7 @@ import ( var ( KeyNotFoundError = errors.New("key not found") + UnexpectedError = errors.New("unexpected error") ) type Store interface {