mirror of
https://github.com/nais/wonderwall.git
synced 2026-05-20 07:12:48 +00:00
refactor: deduplicate crypto operations for sessions
This commit is contained in:
@@ -79,10 +79,10 @@ func run() error {
|
||||
return fmt.Errorf("connecting to redis: %w", err)
|
||||
}
|
||||
|
||||
sessionStore = session.NewRedis(redisClient, crypt)
|
||||
sessionStore = session.NewRedis(redisClient)
|
||||
log.Infof("Using Redis as session backing store")
|
||||
} else {
|
||||
sessionStore = session.NewMemory(crypt)
|
||||
sessionStore = session.NewMemory()
|
||||
log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!")
|
||||
}
|
||||
|
||||
|
||||
@@ -372,22 +372,20 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
sessionID := h.localSessionID(sid)
|
||||
|
||||
sess, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
// From here on, check that 'iss' from request matches data found in access token.
|
||||
accessToken, err := h.getAccessTokenFromSession(r, sessionID)
|
||||
if accessToken == nil {
|
||||
// Can't remove session because it doesn't exist. Maybe it was garbage collected.
|
||||
// We regard this as a redundant logout and return 200 OK.
|
||||
return
|
||||
}
|
||||
|
||||
// From here on, check that 'iss' from request matches data found in access token.
|
||||
tok, err := jwt.Parse([]byte(sess.OAuth2Token.AccessToken))
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = jwt.Validate(tok, jwt.WithClaimValue("iss", iss))
|
||||
err = jwt.Validate(accessToken, jwt.WithClaimValue("iss", iss))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
|
||||
@@ -87,7 +87,7 @@ func handler(cfg config.IDPorten) *router.Handler {
|
||||
}
|
||||
|
||||
crypter := cryptutil.New(encryptionKey)
|
||||
sessionStore := session.NewMemory(crypter)
|
||||
sessionStore := session.NewMemory()
|
||||
|
||||
handler, err := router.NewHandler(cfg, crypter, jwkSet, sessionStore, "")
|
||||
if err != nil {
|
||||
|
||||
@@ -25,11 +25,16 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
|
||||
return nil, fmt.Errorf("no session cookie: %w", err)
|
||||
}
|
||||
|
||||
sessionData, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading session from store: %w", err)
|
||||
}
|
||||
|
||||
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
}
|
||||
|
||||
return sessionData, nil
|
||||
}
|
||||
|
||||
@@ -69,10 +74,36 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
|
||||
IDTokenSerialized: idToken.Raw,
|
||||
}
|
||||
|
||||
err = h.Sessions.Write(r.Context(), sessionID, sessionData, sessionLifetime)
|
||||
encryptedSessionData, err := sessionData.Encrypt(h.Crypter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypting session data: %w", err)
|
||||
}
|
||||
|
||||
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing session to store: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) (jwt.Token, error) {
|
||||
encryptedSession, err := h.Sessions.Read(r.Context(), sessionID)
|
||||
if err != nil {
|
||||
// Session not found; ignoring
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
sessionData, err := encryptedSession.Decrypt(h.Crypter)
|
||||
if err != nil {
|
||||
// Can't decrypt, likely not our session; ignoring
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
accessToken, err := jwt.Parse([]byte(sessionData.OAuth2Token.AccessToken))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing session access token: %w", err)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package session
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
@@ -11,19 +10,17 @@ import (
|
||||
type memorySessionStore struct {
|
||||
lock sync.Mutex
|
||||
sessions map[string]*EncryptedData
|
||||
crypter cryptutil.Crypter
|
||||
}
|
||||
|
||||
var _ Store = &memorySessionStore{}
|
||||
|
||||
func NewMemory(crypter cryptutil.Crypter) Store {
|
||||
func NewMemory() Store {
|
||||
return &memorySessionStore{
|
||||
sessions: make(map[string]*EncryptedData),
|
||||
crypter: crypter,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error) {
|
||||
func (s *memorySessionStore) Read(_ context.Context, key string) (*EncryptedData, error) {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
@@ -32,24 +29,14 @@ func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error)
|
||||
return nil, fmt.Errorf("no such session: %s", key)
|
||||
}
|
||||
|
||||
decrypted, err := data.Decrypt(s.crypter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *memorySessionStore) Write(_ context.Context, key string, value *Data, expiration time.Duration) error {
|
||||
func (s *memorySessionStore) Write(_ context.Context, key string, value *EncryptedData, expiration time.Duration) error {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
encrypted, err := value.Encrypt(s.crypter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypting session data: %w", err)
|
||||
}
|
||||
|
||||
s.sessions[key] = encrypted
|
||||
s.sessions[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -26,13 +26,20 @@ func TestMemory(t *testing.T) {
|
||||
IDTokenSerialized: "idtoken",
|
||||
}
|
||||
|
||||
sess := session.NewMemory(crypter)
|
||||
err = sess.Write(context.Background(), "key", data, time.Minute)
|
||||
encryptedData, err := data.Encrypt(crypter)
|
||||
assert.NoError(t, err)
|
||||
|
||||
sess := session.NewMemory()
|
||||
err = sess.Write(context.Background(), "key", encryptedData, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := sess.Read(context.Background(), "key")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data, result)
|
||||
assert.Equal(t, encryptedData, result)
|
||||
|
||||
decrypted, err := result.Decrypt(crypter)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data, decrypted)
|
||||
|
||||
err = sess.Delete(context.Background(), "key")
|
||||
|
||||
|
||||
@@ -2,28 +2,24 @@ package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/nais/wonderwall/pkg/cryptutil"
|
||||
"github.com/nais/wonderwall/pkg/metrics"
|
||||
"time"
|
||||
)
|
||||
|
||||
type redisSessionStore struct {
|
||||
client redis.Cmdable
|
||||
crypter cryptutil.Crypter
|
||||
}
|
||||
|
||||
var _ Store = &redisSessionStore{}
|
||||
|
||||
func NewRedis(client redis.Cmdable, crypter cryptutil.Crypter) Store {
|
||||
func NewRedis(client redis.Cmdable) Store {
|
||||
return &redisSessionStore{
|
||||
client: client,
|
||||
crypter: crypter,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error) {
|
||||
func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedData, error) {
|
||||
encryptedData := &EncryptedData{}
|
||||
err := metrics.ObserveRedisLatency("Read", func() error {
|
||||
var err error
|
||||
@@ -35,22 +31,12 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, err := encryptedData.Decrypt(s.crypter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypting session data: %w", err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
return encryptedData, nil
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Write(ctx context.Context, key string, value *Data, expiration time.Duration) error {
|
||||
encryptedData, err := value.Encrypt(s.crypter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error {
|
||||
return metrics.ObserveRedisLatency("Write", func() error {
|
||||
status := s.client.Set(ctx, key, encryptedData, expiration)
|
||||
status := s.client.Set(ctx, key, value, expiration)
|
||||
return status.Err()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -30,18 +30,25 @@ func TestRedis(t *testing.T) {
|
||||
IDTokenSerialized: "idtoken",
|
||||
}
|
||||
|
||||
encryptedData, err := data.Encrypt(crypter)
|
||||
assert.NoError(t, err)
|
||||
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Network: "tcp",
|
||||
Addr: "127.0.0.1:6379",
|
||||
})
|
||||
|
||||
sess := session.NewRedis(client, crypter)
|
||||
err = sess.Write(context.Background(), "key", data, time.Minute)
|
||||
sess := session.NewRedis(client)
|
||||
err = sess.Write(context.Background(), "key", encryptedData, time.Minute)
|
||||
assert.NoError(t, err)
|
||||
|
||||
result, err := sess.Read(context.Background(), "key")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data, result)
|
||||
assert.Equal(t, encryptedData, result)
|
||||
|
||||
decrypted, err := result.Decrypt(crypter)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, data, decrypted)
|
||||
|
||||
err = sess.Delete(context.Background(), "key")
|
||||
|
||||
|
||||
@@ -12,8 +12,8 @@ import (
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
Write(ctx context.Context, key string, value *Data, expiration time.Duration) error
|
||||
Read(ctx context.Context, key string) (*Data, error)
|
||||
Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error
|
||||
Read(ctx context.Context, key string) (*EncryptedData, error)
|
||||
Delete(ctx context.Context, keys ...string) error
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user