mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-23 00:32:55 +00:00
refactor: clean up session error handling
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/nais/wonderwall/pkg/handler/url"
|
||||
mw "github.com/nais/wonderwall/pkg/middleware"
|
||||
"github.com/nais/wonderwall/pkg/session"
|
||||
)
|
||||
|
||||
// Default proxies all requests upstream.
|
||||
@@ -12,7 +16,7 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
logger := mw.LogEntry(r).WithField("request_path", r.URL.Path)
|
||||
isAuthenticated := false
|
||||
|
||||
accessToken, ok := h.accessToken(r)
|
||||
accessToken, ok := h.accessToken(r, logger)
|
||||
if ok {
|
||||
// add authentication if session cookie and token checks out
|
||||
isAuthenticated = true
|
||||
@@ -44,11 +48,15 @@ func (h *Handler) Default(w http.ResponseWriter, r *http.Request) {
|
||||
h.ReverseProxy.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
func (h *Handler) accessToken(r *http.Request) (string, bool) {
|
||||
func (h *Handler) accessToken(r *http.Request, logger *log.Entry) (string, bool) {
|
||||
sessionData, err := h.getSessionFromCookie(r)
|
||||
if err != nil || sessionData == nil || len(sessionData.AccessToken) == 0 {
|
||||
return "", false
|
||||
if err == nil && sessionData != nil && len(sessionData.AccessToken) > 0 {
|
||||
return sessionData.AccessToken, true
|
||||
}
|
||||
|
||||
return sessionData.AccessToken, true
|
||||
if errors.Is(err, session.UnexpectedError) {
|
||||
logger.Errorf("default: getting session: %+v", err)
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -36,10 +36,10 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
|
||||
}
|
||||
|
||||
if errors.Is(err, session.KeyNotFoundError) {
|
||||
return nil, fmt.Errorf("session not found in store: %w", err)
|
||||
return nil, fmt.Errorf("session not found: %w", err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("get session: store is unavailable: %+v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data, error) {
|
||||
@@ -52,7 +52,6 @@ func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data,
|
||||
return nil
|
||||
}
|
||||
|
||||
err = fmt.Errorf("reading session data from store: %w", err)
|
||||
if errors.Is(err, session.KeyNotFoundError) {
|
||||
return err
|
||||
}
|
||||
@@ -61,7 +60,7 @@ func (h *Handler) getSession(r *http.Request, sessionID string) (*session.Data,
|
||||
}
|
||||
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("reading from store: %w", err)
|
||||
}
|
||||
|
||||
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
|
||||
@@ -122,7 +121,7 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, tokens *
|
||||
}
|
||||
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return fmt.Errorf("create session: store is unavailable: %+v", err)
|
||||
return fmt.Errorf("writing to store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -135,7 +134,6 @@ func (h *Handler) destroySession(r *http.Request, sessionID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = fmt.Errorf("deleting session from store: %w", err)
|
||||
if errors.Is(err, session.KeyNotFoundError) {
|
||||
return err
|
||||
}
|
||||
@@ -144,7 +142,7 @@ func (h *Handler) destroySession(r *http.Request, sessionID string) error {
|
||||
}
|
||||
|
||||
if err := retry.Do(r.Context(), retrypkg.DefaultBackoff, retryable); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("deleting from store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -52,5 +52,6 @@ func TestMemory(t *testing.T) {
|
||||
|
||||
result, err = sess.Read(context.Background(), "key")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, session.KeyNotFoundError)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
@@ -36,13 +36,18 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedDat
|
||||
return nil, fmt.Errorf("%w: %s", KeyNotFoundError, err.Error())
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("%w: %s", UnexpectedError, err.Error())
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error {
|
||||
return metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error {
|
||||
err := metrics.ObserveRedisLatency(metrics.RedisOperationWrite, func() error {
|
||||
return s.client.Set(ctx, key, value, expiration).Err()
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %s", UnexpectedError, err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error {
|
||||
@@ -57,5 +62,5 @@ func (s *redisSessionStore) Delete(ctx context.Context, keys ...string) error {
|
||||
return fmt.Errorf("%w: %s", KeyNotFoundError, err.Error())
|
||||
}
|
||||
|
||||
return err
|
||||
return fmt.Errorf("%w: %s", UnexpectedError, err.Error())
|
||||
}
|
||||
|
||||
@@ -65,5 +65,6 @@ func TestRedis(t *testing.T) {
|
||||
|
||||
result, err = sess.Read(context.Background(), "key")
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, session.KeyNotFoundError)
|
||||
assert.Nil(t, result)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
var (
|
||||
KeyNotFoundError = errors.New("key not found")
|
||||
UnexpectedError = errors.New("unexpected error")
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
|
||||
Reference in New Issue
Block a user