refactor: deduplicate crypto operations for sessions

This commit is contained in:
Trong Huu Nguyen
2021-09-30 18:27:53 +02:00
parent 8f9cb671c6
commit cc8ba980ca
9 changed files with 72 additions and 56 deletions

View File

@@ -79,10 +79,10 @@ func run() error {
return fmt.Errorf("connecting to redis: %w", err)
}
sessionStore = session.NewRedis(redisClient, crypt)
sessionStore = session.NewRedis(redisClient)
log.Infof("Using Redis as session backing store")
} else {
sessionStore = session.NewMemory(crypt)
sessionStore = session.NewMemory()
log.Warnf("Redis not configured, using in-memory session backing store; not suitable for multi-pod deployments!")
}

View File

@@ -372,22 +372,20 @@ func (h *Handler) FrontChannelLogout(w http.ResponseWriter, r *http.Request) {
sessionID := h.localSessionID(sid)
sess, err := h.Sessions.Read(r.Context(), sessionID)
if err != nil {
// From here on, check that 'iss' from request matches data found in access token.
accessToken, err := h.getAccessTokenFromSession(r, sessionID)
if accessToken == nil {
// Can't remove session because it doesn't exist. Maybe it was garbage collected.
// We regard this as a redundant logout and return 200 OK.
return
}
// From here on, check that 'iss' from request matches data found in access token.
tok, err := jwt.Parse([]byte(sess.OAuth2Token.AccessToken))
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
err = jwt.Validate(tok, jwt.WithClaimValue("iss", iss))
err = jwt.Validate(accessToken, jwt.WithClaimValue("iss", iss))
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return

View File

@@ -87,7 +87,7 @@ func handler(cfg config.IDPorten) *router.Handler {
}
crypter := cryptutil.New(encryptionKey)
sessionStore := session.NewMemory(crypter)
sessionStore := session.NewMemory()
handler, err := router.NewHandler(cfg, crypter, jwkSet, sessionStore, "")
if err != nil {

View File

@@ -25,11 +25,16 @@ func (h *Handler) getSessionFromCookie(r *http.Request) (*session.Data, error) {
return nil, fmt.Errorf("no session cookie: %w", err)
}
sessionData, err := h.Sessions.Read(r.Context(), sessionID)
encryptedSessionData, err := h.Sessions.Read(r.Context(), sessionID)
if err != nil {
return nil, fmt.Errorf("reading session from store: %w", err)
}
sessionData, err := encryptedSessionData.Decrypt(h.Crypter)
if err != nil {
return nil, fmt.Errorf("decrypting session data: %w", err)
}
return sessionData, nil
}
@@ -69,10 +74,36 @@ func (h *Handler) createSession(w http.ResponseWriter, r *http.Request, external
IDTokenSerialized: idToken.Raw,
}
err = h.Sessions.Write(r.Context(), sessionID, sessionData, sessionLifetime)
encryptedSessionData, err := sessionData.Encrypt(h.Crypter)
if err != nil {
return fmt.Errorf("encrypting session data: %w", err)
}
err = h.Sessions.Write(r.Context(), sessionID, encryptedSessionData, sessionLifetime)
if err != nil {
return fmt.Errorf("writing session to store: %w", err)
}
return nil
}
func (h *Handler) getAccessTokenFromSession(r *http.Request, sessionID string) (jwt.Token, error) {
encryptedSession, err := h.Sessions.Read(r.Context(), sessionID)
if err != nil {
// Session not found; ignoring
return nil, nil
}
sessionData, err := encryptedSession.Decrypt(h.Crypter)
if err != nil {
// Can't decrypt, likely not our session; ignoring
return nil, nil
}
accessToken, err := jwt.Parse([]byte(sessionData.OAuth2Token.AccessToken))
if err != nil {
return nil, fmt.Errorf("parsing session access token: %w", err)
}
return accessToken, nil
}

View File

@@ -3,7 +3,6 @@ package session
import (
"context"
"fmt"
"github.com/nais/wonderwall/pkg/cryptutil"
"sync"
"time"
)
@@ -11,19 +10,17 @@ import (
type memorySessionStore struct {
lock sync.Mutex
sessions map[string]*EncryptedData
crypter cryptutil.Crypter
}
var _ Store = &memorySessionStore{}
func NewMemory(crypter cryptutil.Crypter) Store {
func NewMemory() Store {
return &memorySessionStore{
sessions: make(map[string]*EncryptedData),
crypter: crypter,
}
}
func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error) {
func (s *memorySessionStore) Read(_ context.Context, key string) (*EncryptedData, error) {
s.lock.Lock()
defer s.lock.Unlock()
@@ -32,24 +29,14 @@ func (s *memorySessionStore) Read(_ context.Context, key string) (*Data, error)
return nil, fmt.Errorf("no such session: %s", key)
}
decrypted, err := data.Decrypt(s.crypter)
if err != nil {
return nil, fmt.Errorf("decrypting session data: %w", err)
}
return decrypted, nil
return data, nil
}
func (s *memorySessionStore) Write(_ context.Context, key string, value *Data, expiration time.Duration) error {
func (s *memorySessionStore) Write(_ context.Context, key string, value *EncryptedData, expiration time.Duration) error {
s.lock.Lock()
defer s.lock.Unlock()
encrypted, err := value.Encrypt(s.crypter)
if err != nil {
return fmt.Errorf("encrypting session data: %w", err)
}
s.sessions[key] = encrypted
s.sessions[key] = value
return nil
}

View File

@@ -26,13 +26,20 @@ func TestMemory(t *testing.T) {
IDTokenSerialized: "idtoken",
}
sess := session.NewMemory(crypter)
err = sess.Write(context.Background(), "key", data, time.Minute)
encryptedData, err := data.Encrypt(crypter)
assert.NoError(t, err)
sess := session.NewMemory()
err = sess.Write(context.Background(), "key", encryptedData, time.Minute)
assert.NoError(t, err)
result, err := sess.Read(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, data, result)
assert.Equal(t, encryptedData, result)
decrypted, err := result.Decrypt(crypter)
assert.NoError(t, err)
assert.Equal(t, data, decrypted)
err = sess.Delete(context.Background(), "key")

View File

@@ -2,28 +2,24 @@ package session
import (
"context"
"fmt"
"github.com/go-redis/redis/v8"
"github.com/nais/wonderwall/pkg/cryptutil"
"github.com/nais/wonderwall/pkg/metrics"
"time"
)
type redisSessionStore struct {
client redis.Cmdable
crypter cryptutil.Crypter
}
var _ Store = &redisSessionStore{}
func NewRedis(client redis.Cmdable, crypter cryptutil.Crypter) Store {
func NewRedis(client redis.Cmdable) Store {
return &redisSessionStore{
client: client,
crypter: crypter,
}
}
func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error) {
func (s *redisSessionStore) Read(ctx context.Context, key string) (*EncryptedData, error) {
encryptedData := &EncryptedData{}
err := metrics.ObserveRedisLatency("Read", func() error {
var err error
@@ -35,22 +31,12 @@ func (s *redisSessionStore) Read(ctx context.Context, key string) (*Data, error)
return nil, err
}
data, err := encryptedData.Decrypt(s.crypter)
if err != nil {
return nil, fmt.Errorf("decrypting session data: %w", err)
}
return data, nil
return encryptedData, nil
}
func (s *redisSessionStore) Write(ctx context.Context, key string, value *Data, expiration time.Duration) error {
encryptedData, err := value.Encrypt(s.crypter)
if err != nil {
return err
}
func (s *redisSessionStore) Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error {
return metrics.ObserveRedisLatency("Write", func() error {
status := s.client.Set(ctx, key, encryptedData, expiration)
status := s.client.Set(ctx, key, value, expiration)
return status.Err()
})
}

View File

@@ -30,18 +30,25 @@ func TestRedis(t *testing.T) {
IDTokenSerialized: "idtoken",
}
encryptedData, err := data.Encrypt(crypter)
assert.NoError(t, err)
client := redis.NewClient(&redis.Options{
Network: "tcp",
Addr: "127.0.0.1:6379",
})
sess := session.NewRedis(client, crypter)
err = sess.Write(context.Background(), "key", data, time.Minute)
sess := session.NewRedis(client)
err = sess.Write(context.Background(), "key", encryptedData, time.Minute)
assert.NoError(t, err)
result, err := sess.Read(context.Background(), "key")
assert.NoError(t, err)
assert.Equal(t, data, result)
assert.Equal(t, encryptedData, result)
decrypted, err := result.Decrypt(crypter)
assert.NoError(t, err)
assert.Equal(t, data, decrypted)
err = sess.Delete(context.Background(), "key")

View File

@@ -12,8 +12,8 @@ import (
)
type Store interface {
Write(ctx context.Context, key string, value *Data, expiration time.Duration) error
Read(ctx context.Context, key string) (*Data, error)
Write(ctx context.Context, key string, value *EncryptedData, expiration time.Duration) error
Read(ctx context.Context, key string) (*EncryptedData, error)
Delete(ctx context.Context, keys ...string) error
}