diff --git a/pkg/handler/handler_callback.go b/pkg/handler/handler_callback.go index 0698db8..37c1662 100644 --- a/pkg/handler/handler_callback.go +++ b/pkg/handler/handler_callback.go @@ -17,7 +17,6 @@ import ( ) func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) { - // unconditionally clear login cookie h.clearLoginCookies(w) diff --git a/pkg/handler/handler_logout.go b/pkg/handler/handler_logout.go index 14ae172..21a4b62 100644 --- a/pkg/handler/handler_logout.go +++ b/pkg/handler/handler_logout.go @@ -5,11 +5,11 @@ import ( "fmt" "net/http" - "github.com/go-redis/redis/v8" log "github.com/sirupsen/logrus" "github.com/nais/wonderwall/pkg/cookie" logentry "github.com/nais/wonderwall/pkg/middleware" + "github.com/nais/wonderwall/pkg/session" ) // Logout triggers self-initiated for the current user @@ -22,7 +22,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) { if err == nil && sessionData != nil { idToken = sessionData.IDToken err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID)) - if err != nil && !errors.Is(err, redis.Nil) { + if err != nil && !errors.Is(err, session.KeyNotFoundError) { h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err)) return } diff --git a/pkg/handler/session.go b/pkg/handler/session.go index cb6c6b4..2de9a9b 100644 --- a/pkg/handler/session.go +++ b/pkg/handler/session.go @@ -7,7 +7,6 @@ import ( "net/http" "time" - "github.com/go-redis/redis/v8" "github.com/sethvargo/go-retry" "github.com/nais/wonderwall/pkg/cookie" @@ -38,7 +37,7 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) ( return sessionData, nil } - if errors.Is(err, redis.Nil) { + if errors.Is(err, session.KeyNotFoundError) { return nil, fmt.Errorf("session not found in store: %w", err) } @@ -63,7 +62,7 @@ func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, } err = fmt.Errorf("reading session data from store: %w", err) - if errors.Is(err, redis.Nil) { + if errors.Is(err, session.KeyNotFoundError) { return err } @@ -151,7 +150,7 @@ func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, session } err = fmt.Errorf("deleting session from store: %w", err) - if errors.Is(err, redis.Nil) { + if errors.Is(err, session.KeyNotFoundError) { return err } diff --git a/pkg/session/memory.go b/pkg/session/memory.go index 4615e72..8a8e6c9 100644 --- a/pkg/session/memory.go +++ b/pkg/session/memory.go @@ -26,7 +26,7 @@ func (s *memorySessionStore) Read(_ context.Context, key string) (*EncryptedData data, ok := s.sessions[key] if !ok { - return nil, fmt.Errorf("no such session: %s", key) + return nil, fmt.Errorf("%w: no such session: %s", KeyNotFoundError, key) } return data, nil diff --git a/pkg/session/redis.go b/pkg/session/redis.go index 5e79a29..cbc2f66 100644 --- a/pkg/session/redis.go +++ b/pkg/session/redis.go @@ -2,6 +2,8 @@ package session import ( "context" + "errors" + "fmt" "time" "github.com/go-redis/redis/v8" @@ -26,11 +28,15 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedDat err := metrics.ObserveRedisLatency("Read", func() error { return s.client.Get(ctx, key).Scan(encryptedData) }) - if err != nil { - return nil, err + if err == nil { + return encryptedData, nil } - return encryptedData, nil + if errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("%w: %s", KeyNotFoundError, err.Error()) + } + + return nil, err } func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error { @@ -40,7 +46,16 @@ func (s *redisSessionStore) Write(ctx context.Context, key string, value *Encryp } func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error { - return metrics.ObserveRedisLatency("Delete", func() error { + err := metrics.ObserveRedisLatency("Delete", func() error { return s.client.Del(ctx, keys...).Err() }) + if err == nil { + return nil + } + + if errors.Is(err, redis.Nil) { + return fmt.Errorf("%w: %s", KeyNotFoundError, err.Error()) + } + + return err } diff --git a/pkg/session/session.go b/pkg/session/session.go index d906294..64d1d60 100644 --- a/pkg/session/session.go +++ b/pkg/session/session.go @@ -2,6 +2,7 @@ package session import ( "context" + "errors" "time" log "github.com/sirupsen/logrus" @@ -9,6 +10,10 @@ import ( "github.com/nais/wonderwall/pkg/config" ) +var ( + KeyNotFoundError = errors.New("key not found") +) + type Store interface { Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error Read(ctx context.Context, key string) (*EncryptedData, error)