refactor(session): wrap own error type instead of using store-specific errors

This commit is contained in:
Trong Huu Nguyen
2022-07-18 19:16:28 +02:00
parent 4ab07e9dc2
commit b674a0ffa7
6 changed files with 30 additions and 12 deletions

View File

@@ -17,7 +17,6 @@ import (
)
func (h *Handler) Callback(w http.ResponseWriter, r *http.Request) {
// unconditionally clear login cookie
h.clearLoginCookies(w)

View File

@@ -5,11 +5,11 @@ import (
"fmt"
"net/http"
"github.com/go-redis/redis/v8"
log "github.com/sirupsen/logrus"
"github.com/nais/wonderwall/pkg/cookie"
logentry "github.com/nais/wonderwall/pkg/middleware"
"github.com/nais/wonderwall/pkg/session"
)
// Logout triggers self-initiated for the current user
@@ -22,7 +22,7 @@ func (h *Handler) Logout(w http.ResponseWriter, r *http.Request) {
if err == nil && sessionData != nil {
idToken = sessionData.IDToken
err = h.destroySession(w, r, h.localSessionID(sessionData.ExternalSessionID))
if err != nil && !errors.Is(err, redis.Nil) {
if err != nil && !errors.Is(err, session.KeyNotFoundError) {
h.InternalError(w, r, fmt.Errorf("logout: destroying session: %w", err))
return
}

View File

@@ -7,7 +7,6 @@ import (
"net/http"
"time"
"github.com/go-redis/redis/v8"
"github.com/sethvargo/go-retry"
"github.com/nais/wonderwall/pkg/cookie"
@@ -38,7 +37,7 @@ func (h *Handler) getSessionFromCookie(w http.ResponseWriter, r *http.Request) (
return sessionData, nil
}
if errors.Is(err, redis.Nil) {
if errors.Is(err, session.KeyNotFoundError) {
return nil, fmt.Errorf("session not found in store: %w", err)
}
@@ -63,7 +62,7 @@ func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data,
}
err = fmt.Errorf("reading session data from store: %w", err)
if errors.Is(err, redis.Nil) {
if errors.Is(err, session.KeyNotFoundError) {
return err
}
@@ -151,7 +150,7 @@ func (h *Handler) destroySession(w http.ResponseWriter, r *http.Request, session
}
err = fmt.Errorf("deleting session from store: %w", err)
if errors.Is(err, redis.Nil) {
if errors.Is(err, session.KeyNotFoundError) {
return err
}

View File

@@ -26,7 +26,7 @@ func (s *memorySessionStore) Read(_ context.Context, key string) (*EncryptedData
data, ok := s.sessions[key]
if !ok {
return nil, fmt.Errorf("no such session: %s", key)
return nil, fmt.Errorf("%w: no such session: %s", KeyNotFoundError, key)
}
return data, nil

View File

@@ -2,6 +2,8 @@ package session
import (
"context"
"errors"
"fmt"
"time"
"github.com/go-redis/redis/v8"
@@ -26,11 +28,15 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedDat
err := metrics.ObserveRedisLatency("Read", func() error {
return s.client.Get(ctx, key).Scan(encryptedData)
})
if err != nil {
return nil, err
if err == nil {
return encryptedData, nil
}
return encryptedData, nil
if errors.Is(err, redis.Nil) {
return nil, fmt.Errorf("%w: %s", KeyNotFoundError, err.Error())
}
return nil, err
}
func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error {
@@ -40,7 +46,16 @@ func (s *redisSessionStore) Write(ctx context.Context, key string, value *Encryp
}
func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error {
return metrics.ObserveRedisLatency("Delete", func() error {
err := metrics.ObserveRedisLatency("Delete", func() error {
return s.client.Del(ctx, keys...).Err()
})
if err == nil {
return nil
}
if errors.Is(err, redis.Nil) {
return fmt.Errorf("%w: %s", KeyNotFoundError, err.Error())
}
return err
}

View File

@@ -2,6 +2,7 @@ package session
import (
"context"
"errors"
"time"
log "github.com/sirupsen/logrus"
@@ -9,6 +10,10 @@ import (
"github.com/nais/wonderwall/pkg/config"
)
var (
KeyNotFoundError = errors.New("key not found")
)
type Store interface {
Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error
Read(ctx context.Context, key string) (*EncryptedData, error)